Created
May 10, 2020 19:10
-
-
Save haampie/26d5458aaa4efd0e47a62fb723a06a2d to your computer and use it in GitHub Desktop.
givens.jl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using SIMD | |
using BenchmarkTools | |
using Test | |
"""
    Rot{T}

A 2×2 plane rotation described by two coefficients of the same type `T`.

# Fields
- `c::T`: coefficient that multiplies the diagonal entries (see `mul`).
- `s::T`: coefficient that multiplies the off-diagonal entries (see `mul`).
"""
struct Rot{T}
    c::T
    s::T
end
"""
    mul(G::Rot, a, b)

Apply the rotation `G` to the pair `(a, b)` and return the rotated pair
`(G.c * a + G.s * b, -G.s * a + G.c * b)`.
"""
function mul(G::Rot, a, b)
    c, s = G.c, G.s
    first_out = c * a + s * b
    second_out = -s * a + c * b
    return (first_out, second_out)
end
"""
    reference_impl!(A::AbstractMatrix{Float64}, givens::AbstractArray{Float64})

Apply four plane rotations, in place, to the first four columns of every row of `A`
and return `A`. This is the scalar reference implementation.

`givens` stores the rotations as consecutive `(c, s)` pairs: `givens[1:2]` is rotation 1,
`givens[3:4]` is rotation 2, `givens[5:6]` is rotation 3 and `givens[7:8]` is rotation 4.
Within each row they are applied in this order:

1. rotation 1 to columns 2 and 3,
2. rotation 2 to columns 3 and 4,
3. rotation 3 to columns 1 and 2,
4. rotation 4 to columns 2 and 3.

Columns beyond the fourth are left untouched.

# Throws
- `DimensionMismatch` if `A` has fewer than 4 columns or `givens` has fewer than 8
  entries.
- Arrays that are not 1-based are rejected by `Base.require_one_based_indexing`.
"""
function reference_impl!(A::AbstractMatrix{Float64}, givens::AbstractArray{Float64})
    # The loop below uses hard-coded indices under `@inbounds`, so validate up front.
    Base.require_one_based_indexing(A, givens)
    size(A, 2) >= 4 ||
        throw(DimensionMismatch("A must have at least 4 columns, got $(size(A, 2))"))
    length(givens) >= 8 ||
        throw(DimensionMismatch("givens must have at least 8 entries, got $(length(givens))"))

    # Loaded once outside the hot loop; bounds checks here are free.
    G1 = Rot(givens[1], givens[2])
    G2 = Rot(givens[3], givens[4])
    G3 = Rot(givens[5], givens[6])
    G4 = Rot(givens[7], givens[8])

    @inbounds for row in axes(A, 1)
        A1 = A[row, 1]
        A2 = A[row, 2]
        A3 = A[row, 3]
        A4 = A[row, 4]

        # Apply rotation 1 to column 2 and 3
        A2′, A3′ = mul(G1, A2, A3)
        # Apply rotation 2 to column 3 and 4
        A3′′, A4′ = mul(G2, A3′, A4)
        # Apply rotation 3 to column 1 and 2
        A1′, A2′′ = mul(G3, A1, A2′)
        # Apply rotation 4 to column 2 and 3
        A2′′′, A3′′′ = mul(G4, A2′′, A3′′)

        A[row, 1] = A1′
        A[row, 2] = A2′′′
        A[row, 3] = A3′′′
        A[row, 4] = A4′
    end

    return A
end
"""
    avx_givens_first!(A::AbstractMatrix{Float64}, givens::AbstractArray{Float64})

SIMD version of `reference_impl!`: apply the four rotations stored in `givens` (as
consecutive `(c, s)` pairs) in place to the first four columns of `A` and return `A`.

Rows are processed four at a time with `Vec{4,Float64}` loads and stores on raw column
pointers. Rows left over when `size(A, 1)` is not a multiple of 4 are handled by
`reference_impl!`, so every row is rotated. If `A` does not have unit stride along its
first dimension the whole call is delegated to `reference_impl!`, because the vector
loads require contiguous rows within a column.

# Throws
- `DimensionMismatch` if `A` has fewer than 4 columns or `givens` has fewer than 8
  entries.
- Arrays that are not 1-based are rejected by `Base.require_one_based_indexing`.
"""
function avx_givens_first!(A::AbstractMatrix{Float64}, givens::AbstractArray{Float64})
    # Raw pointer arithmetic follows, so reject inputs that would read out of bounds.
    Base.require_one_based_indexing(A, givens)
    size(A, 2) >= 4 ||
        throw(DimensionMismatch("A must have at least 4 columns, got $(size(A, 2))"))
    length(givens) >= 8 ||
        throw(DimensionMismatch("givens must have at least 8 entries, got $(length(givens))"))

    # vload/vstore read 4 consecutive Float64s; that is only a column slice when the
    # row stride is 1.
    stride(A, 1) == 1 || return reference_impl!(A, givens)

    nrows = size(A, 1)
    nchunks = nrows ÷ 4
    col_bytes = stride(A, 2) * sizeof(Float64)
    step_bytes = 4 * sizeof(Float64)

    # Broadcast each rotation coefficient into all four lanes.
    c1 = Vec{4,Float64}(givens[1])
    s1 = Vec{4,Float64}(givens[2])
    c2 = Vec{4,Float64}(givens[3])
    s2 = Vec{4,Float64}(givens[4])
    c3 = Vec{4,Float64}(givens[5])
    s3 = Vec{4,Float64}(givens[6])
    c4 = Vec{4,Float64}(givens[7])
    s4 = Vec{4,Float64}(givens[8])

    # Keep `A` rooted while we work through raw pointers into its memory.
    GC.@preserve A begin
        A_col_1 = pointer(A) + 0 * col_bytes
        A_col_2 = pointer(A) + 1 * col_bytes
        A_col_3 = pointer(A) + 2 * col_bytes
        A_col_4 = pointer(A) + 3 * col_bytes

        for _ in 1:nchunks
            # Load the columns
            col_1 = vload(Vec{4,Float64}, A_col_1)
            col_2 = vload(Vec{4,Float64}, A_col_2)
            col_3 = vload(Vec{4,Float64}, A_col_3)
            col_4 = vload(Vec{4,Float64}, A_col_4)

            # Apply rotation 1 to column 2 and 3
            col_2′ = s1 * col_3 + c1 * col_2
            col_3′ = c1 * col_3 - s1 * col_2
            # Apply rotation 2 to column 3 and 4
            col_3′′ = s2 * col_4 + c2 * col_3′
            col_4′ = c2 * col_4 - s2 * col_3′
            # Apply rotation 3 to column 1 and 2
            col_1′ = s3 * col_2′ + c3 * col_1
            col_2′′ = c3 * col_2′ - s3 * col_1
            # Apply rotation 4 to column 2 and 3
            col_2′′′ = s4 * col_3′′ + c4 * col_2′′
            col_3′′′ = c4 * col_3′′ - s4 * col_2′′

            vstore(col_1′, A_col_1)
            vstore(col_2′′′, A_col_2)
            vstore(col_3′′′, A_col_3)
            vstore(col_4′, A_col_4)

            A_col_1 += step_bytes
            A_col_2 += step_bytes
            A_col_3 += step_bytes
            A_col_4 += step_bytes
        end
    end

    # Rows that did not fill a whole 4-wide vector are rotated by the scalar kernel.
    done = 4 * nchunks
    if done < nrows
        reference_impl!(view(A, (done + 1):nrows, :), givens)
    end

    return A
end
"""
    avx_givens_second!(A::AbstractMatrix{Float64}, givens::AbstractArray{Float64})

SIMD version of `reference_impl!` that handles eight rows per iteration: apply the four
rotations stored in `givens` (as consecutive `(c, s)` pairs) in place to the first four
columns of `A` and return `A`.

Each iteration loads two `Vec{4,Float64}` vectors per column (rows `k:k+3` and
`k+4:k+7`) and rotates both groups with the same coefficients. Rows left over when
`size(A, 1)` is not a multiple of 8 are handled by `reference_impl!`, so every row is
rotated. If `A` does not have unit stride along its first dimension the whole call is
delegated to `reference_impl!`, because the vector loads require contiguous rows within
a column.

# Throws
- `DimensionMismatch` if `A` has fewer than 4 columns or `givens` has fewer than 8
  entries.
- Arrays that are not 1-based are rejected by `Base.require_one_based_indexing`.
"""
function avx_givens_second!(A::AbstractMatrix{Float64}, givens::AbstractArray{Float64})
    # Raw pointer arithmetic follows, so reject inputs that would read out of bounds.
    Base.require_one_based_indexing(A, givens)
    size(A, 2) >= 4 ||
        throw(DimensionMismatch("A must have at least 4 columns, got $(size(A, 2))"))
    length(givens) >= 8 ||
        throw(DimensionMismatch("givens must have at least 8 entries, got $(length(givens))"))

    # vload/vstore read 4 consecutive Float64s; that is only a column slice when the
    # row stride is 1.
    stride(A, 1) == 1 || return reference_impl!(A, givens)

    nrows = size(A, 1)
    nchunks = nrows ÷ 8
    col_bytes = stride(A, 2) * sizeof(Float64)
    half_bytes = 4 * sizeof(Float64)  # offset of the second 4-row group
    step_bytes = 8 * sizeof(Float64)

    # Broadcast each rotation coefficient into all four lanes.
    c1 = Vec{4,Float64}(givens[1])
    s1 = Vec{4,Float64}(givens[2])
    c2 = Vec{4,Float64}(givens[3])
    s2 = Vec{4,Float64}(givens[4])
    c3 = Vec{4,Float64}(givens[5])
    s3 = Vec{4,Float64}(givens[6])
    c4 = Vec{4,Float64}(givens[7])
    s4 = Vec{4,Float64}(givens[8])

    # Keep `A` rooted while we work through raw pointers into its memory.
    GC.@preserve A begin
        A_col_1 = pointer(A) + 0 * col_bytes
        A_col_2 = pointer(A) + 1 * col_bytes
        A_col_3 = pointer(A) + 2 * col_bytes
        A_col_4 = pointer(A) + 3 * col_bytes

        for _ in 1:nchunks
            # Load the columns. col_5…col_8 are the next four rows of columns 1…4.
            col_1 = vload(Vec{4,Float64}, A_col_1)
            col_5 = vload(Vec{4,Float64}, A_col_1 + half_bytes)
            col_2 = vload(Vec{4,Float64}, A_col_2)
            col_6 = vload(Vec{4,Float64}, A_col_2 + half_bytes)
            col_3 = vload(Vec{4,Float64}, A_col_3)
            col_7 = vload(Vec{4,Float64}, A_col_3 + half_bytes)
            col_4 = vload(Vec{4,Float64}, A_col_4)
            col_8 = vload(Vec{4,Float64}, A_col_4 + half_bytes)

            # Apply the Z-shape rotations
            # Apply rotation 1 to column 2 and 3
            col_2′ = s1 * col_3 + c1 * col_2
            col_3′ = c1 * col_3 - s1 * col_2
            col_6′ = s1 * col_7 + c1 * col_6
            col_7′ = c1 * col_7 - s1 * col_6
            # Apply rotation 2 to column 3 and 4
            col_3′′ = s2 * col_4 + c2 * col_3′
            col_4′ = c2 * col_4 - s2 * col_3′
            col_7′′ = s2 * col_8 + c2 * col_7′
            col_8′ = c2 * col_8 - s2 * col_7′
            # Apply rotation 3 to column 1 and 2
            col_1′ = s3 * col_2′ + c3 * col_1
            col_2′′ = c3 * col_2′ - s3 * col_1
            col_5′ = s3 * col_6′ + c3 * col_5
            col_6′′ = c3 * col_6′ - s3 * col_5
            # Apply rotation 4 to column 2 and 3
            col_2′′′ = s4 * col_3′′ + c4 * col_2′′
            col_3′′′ = c4 * col_3′′ - s4 * col_2′′
            col_6′′′ = s4 * col_7′′ + c4 * col_6′′
            col_7′′′ = c4 * col_7′′ - s4 * col_6′′

            vstore(col_1′, A_col_1)
            vstore(col_5′, A_col_1 + half_bytes)
            vstore(col_2′′′, A_col_2)
            vstore(col_6′′′, A_col_2 + half_bytes)
            vstore(col_3′′′, A_col_3)
            vstore(col_7′′′, A_col_3 + half_bytes)
            vstore(col_4′, A_col_4)
            vstore(col_8′, A_col_4 + half_bytes)

            A_col_1 += step_bytes
            A_col_2 += step_bytes
            A_col_3 += step_bytes
            A_col_4 += step_bytes
        end
    end

    # Rows that did not fill a whole 8-row step are rotated by the scalar kernel.
    done = 8 * nchunks
    if done < nrows
        reference_impl!(view(A, (done + 1):nrows, :), givens)
    end

    return A
end
"""
    bench()

Benchmark `reference_impl!`, `avx_givens_first!` and `avx_givens_second!` on the same
random 8000×4 `Float64` matrix and return the three `@benchmark` results as a tuple, in
that order.

The rotation coefficients are genuine `(cos θ, sin θ)` pairs. Every benchmark sample
re-applies the rotations in place to the same matrix, so coefficients with
`c^2 + s^2 != 1` would scale the data on each evaluation and push it towards
`Inf`/`NaN` or subnormal values, distorting the timings.
"""
function bench()
    A = rand(Float64, 8 * 1000, 4)

    # Layout expected by the kernels: [c1, s1, c2, s2, c3, s3, c4, s4].
    givens = Vector{Float64}(undef, 8)
    for k in 1:4
        s, c = sincos(2π * rand(Float64))
        givens[2k - 1] = c
        givens[2k] = s
    end

    ref = @benchmark reference_impl!($A, $givens)
    avx_first = @benchmark avx_givens_first!($A, $givens)
    avx_second = @benchmark avx_givens_second!($A, $givens)

    return ref, avx_first, avx_second
end
"""
    test()

Check both SIMD kernels against `reference_impl!` on a random 8000×4 matrix with random
coefficients. Each kernel runs on its own copy of the input. Returns a tuple holding the
results of the two `@test` comparisons (first kernel, then second kernel).
"""
function test()
    input = rand(Float64, 8 * 1000, 4)
    givens = rand(Float64, 8)

    expected = reference_impl!(copy(input), givens)
    from_first = avx_givens_first!(copy(input), givens)
    from_second = avx_givens_second!(copy(input), givens)

    first_result = @test expected ≈ from_first
    second_result = @test expected ≈ from_second
    return (first_result, second_result)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment