Skip to content

Instantly share code, notes, and snippets.

@haampie
Created September 25, 2018 14:53
Show Gist options
  • Select an option

  • Save haampie/e7b4c162b0632310716cf1f9da8f0639 to your computer and use it in GitHub Desktop.

Select an option

Save haampie/e7b4c162b0632310716cf1f9da8f0639 to your computer and use it in GitHub Desktop.
session2.jl
using LinearAlgebra
using LinearAlgebra: givensAlgorithm
using Test
using BenchmarkTools
import LinearAlgebra: rmul!
# Supertype of the small, fixed-size rotation objects (`Rotation2`, `Fused2x2`) that are
# applied to a matrix from the right with `rmul!(A, G)` or `rmul!(A, G, range)`.
abstract type SmallRotation end
"""
    Rotation2(c, s, i)

A single Givens rotation with cosine `c` and sine `s` that acts on columns `i` and `i + 1`
when applied from the right with `rmul!`.

`c` and `s` have separate type parameters. `rmul!` uses the adjoint `s'`, so a complex `s`
with a real `c` is presumably meant to be allowed. NOTE(review): confirm.
"""
struct Rotation2{Tc,Ts} <: SmallRotation
c::Tc
s::Ts
i::Int
end
"""
    Fused2x2(c1, s1, c2, s2, c3, s3, c4, s4, i)

Four Givens rotations from two consecutive layers, fused so that `rmul!` can apply all of
them to columns `i:i+3` in a single pass over the rows of a matrix.
The rotations form a staircase:

    G3 G4        second layer: G3 on columns (i, i+1),   G4 on columns (i+1, i+2)
    G1 G2        first layer:  G1 on columns (i+1, i+2), G2 on columns (i+2, i+3)

`(ck, sk)` are the cosine and sine of `Gk`. `rmul!` applies the rotations in the order
G1, G2, G3, G4.
"""
struct Fused2x2{Tc,Ts} <: SmallRotation
c1::Tc
s1::Ts
c2::Tc
s2::Ts
c3::Tc
s3::Ts
c4::Tc
s4::Ts
i::Int
end
"""
    rmul!(A, G::SmallRotation)

Apply the small rotation `G` from the right to every row of `A`, in place.
This forwards to the ranged method `rmul!(A, G, range)` over all row indices of `A` and
returns whatever that method returns.
"""
@inline function rmul!(A::AbstractMatrix, G::SmallRotation)
    all_rows = axes(A, 1)
    return rmul!(A, G, all_rows)
end
"""
    rmul!(A, G::Rotation2, range)

Apply the Givens rotation `G` from the right to the rows `range` of `A`, in place.
For every row `j`, the entries `a₁ = A[j, G.i]` and `a₂ = A[j, G.i + 1]` are replaced by

    a₁′ =  a₁ * G.c + a₂ * G.s'
    a₂′ = -a₁ * G.s + a₂ * G.c

Returns `A`, matching the `Fused2x2` method and the `LinearAlgebra.rmul!` convention.

Indexing is unchecked (`@inbounds`). The caller must ensure `range ⊆ axes(A, 1)` and that
columns `G.i` and `G.i + 1` exist.
"""
@inline function rmul!(A::AbstractMatrix, G::Rotation2, range)
    i₁ = G.i
    i₂ = G.i + 1
    @inbounds for j = range
        # Read both entries before writing, since each output depends on both inputs
        a₁ = A[j, i₁]
        a₂ = A[j, i₂]
        A[j, i₁] = muladd( a₁, G.c, a₂ * G.s')
        A[j, i₂] = muladd(-a₁, G.s, a₂ * G.c )
    end
    # Previously the loop was the last expression, so this method returned `nothing`
    return A
end
# Rotate the pair `(x, y)` by the Givens pair `(c, s)`:
# returns `(x * c + y * s', -x * s + y * c)`.
@inline _rotate_pair(x, y, c, s) = (muladd(x, c, y * s'), muladd(-x, s, y * c))

"""
    rmul!(A, G::Fused2x2, range)

Apply the four fused rotations of `G` from the right to the rows `range` of `A`, in place.
Each row touches only columns `G.i:G.i+3`.
The order is G1 on (i+1, i+2), then G2 on (i+2, i+3), then G3 on (i, i+1), then G4 on
(i+1, i+2). Returns `A`.

Indexing is unchecked (`@inbounds`). The caller must ensure `range ⊆ axes(A, 1)` and that
columns `G.i:G.i+3` exist.
"""
@inline function rmul!(A::AbstractMatrix, G::Fused2x2, range)
    p = G.i
    @inbounds for row = range
        x0, x1, x2, x3 = A[row, p], A[row, p + 1], A[row, p + 2], A[row, p + 3]
        # First layer: G1 then G2, on columns p+1 .. p+3
        x1, x2 = _rotate_pair(x1, x2, G.c1, G.s1)
        x2, x3 = _rotate_pair(x2, x3, G.c2, G.s2)
        # Second layer: G3 then G4, on columns p .. p+2
        x0, x1 = _rotate_pair(x0, x1, G.c3, G.s3)
        x1, x2 = _rotate_pair(x1, x2, G.c4, G.s4)
        A[row, p], A[row, p + 1], A[row, p + 2], A[row, p + 3] = x0, x1, x2, x3
    end
    return A
end
"""
    generate_rotations(cols, ks)

Return a `ks × (cols - 1)` matrix of random Givens pairs `(c, s)`.
Entry `[k, n]` is meant to act on columns `n` and `n + 1` in layer `k`.
Each pair comes from `givensAlgorithm(rand(), rand())`, with the third result `r` dropped.
"""
function generate_rotations(cols, ks)
    pairs_cs = Matrix{Tuple{Float64,Float64}}(undef, ks, cols - 1)
    # CartesianIndices visits the layer index fastest, i.e. column by column
    for idx in CartesianIndices(pairs_cs)
        cosine, sine, _ = givensAlgorithm(rand(), rand())
        pairs_cs[idx] = (cosine, sine)
    end
    return pairs_cs
end
"""
    example(cols, ks, rows = 1000)

Check that `trivial_order!` and `fused_order!` agree on a random `rows × cols` matrix.
The matrix is rotated by `ks` layers of random rotations, and the check also verifies
that the rotations changed the matrix at all.
Returns the result of the last `@test`.
"""
function example(cols, ks, rows = 1000)
    rotations = generate_rotations(cols, ks)
    reference = rand(rows, cols)
    via_trivial = copy(reference)
    via_fused = copy(reference)
    trivial_order!(via_trivial, rotations)
    fused_order!(via_fused, rotations)
    # Both orders apply the same rotation formulas to every column, so the results may
    # well agree bit for bit. `≈` leaves room for `muladd` being fused differently.
    @test via_trivial ≈ via_fused
    # Sanity check: the rotations actually modified the matrix
    @test via_trivial != reference
end
"""
    bench(ns = 100:100:2000, k = 16)

Benchmark `trivial_order!` against `fused_order!` on random `n × n` matrices, for every
`n` in `ns`, using `k` layers of random rotations. The throughput of each size is logged.

Returns `(GFLOPS_a, GFLOPS_b)`: the GFLOP/s of the trivial order and of the fused order,
one entry per `n`.
"""
function bench(ns = 100:100:2000, k = 16)
    trivial_rates = Float64[]
    fused_rates = Float64[]
    for n in ns
        rotations = generate_rotations(n, k)
        # Work per call: k layers × (n - 1) rotations × n rows × 6 flops.
        # Each of the two updated entries costs one multiply plus one muladd.
        flops = 6k * n * (n - 1)
        t_trivial = @belapsed trivial_order!(Q, $rotations) setup = (Q = rand($n, $n))
        t_fused = @belapsed fused_order!(Q, $rotations) setup = (Q = rand($n, $n))
        gflops_per_sec_a = flops / t_trivial / 1e9
        gflops_per_sec_b = flops / t_fused / 1e9
        @info "Size ($n × $n) with $k layers" gflops_per_sec_a gflops_per_sec_b
        push!(trivial_rates, gflops_per_sec_a)
        push!(fused_rates, gflops_per_sec_b)
    end
    return trivial_rates, fused_rates
end
"""
    trivial_order!(Q, rotations)

Apply all rotations to `Q` in place, in the trivial order.
For each layer `k` (row of `rotations`), the rotations `rotations[k, n]` are applied for
increasing `n`. Each one is a `Rotation2` on columns `n` and `n + 1` of `Q`.
`rotations` is a matrix of `(c, s)` pairs, e.g. from `generate_rotations`.

Returns `Q`.

Throws `DimensionMismatch` if `rotations` is non-empty and `Q` has at most
`size(rotations, 2)` columns. The column accesses inside `rmul!` are unchecked
(`@inbounds`), so without this check a too-narrow `Q` would be accessed out of bounds.
"""
function trivial_order!(Q, rotations)
    numrot = size(rotations, 2)
    if !isempty(rotations) && size(Q, 2) <= numrot
        throw(DimensionMismatch(
            "Q has $(size(Q, 2)) columns, but the rotations act on $(numrot + 1) columns"))
    end
    @inbounds for k = axes(rotations, 1)
        for n = axes(rotations, 2)
            c, s = rotations[k, n]
            rmul!(Q, Rotation2(c, s, n))
        end
    end
    return Q
end
"""
    fused_order!(Q, rotations)

Apply all rotations to `Q` in place, grouping them into "fused" 2 × 2 blocks
(`Fused2x2`). Each block combines two neighbouring rotations from layer `k` with two from
layer `k + 1`.

`rotations[k, n]` is a `(c, s)` pair acting on columns `n` and `n + 1` of `Q`, as in
`trivial_order!`. Every rotation is still applied after all rotations that share a column
with it and precede it in the trivial order. Only rotations on disjoint columns are
reordered, so the result matches `trivial_order!`.

Any shape of `rotations` is handled:
- When the number of rotations per layer is even, the right-edge rotations of each layer
  pair are applied individually.
- An unpaired final layer (odd number of layers) is applied rotation by rotation.

Returns `Q`.

Throws `DimensionMismatch` if `rotations` is non-empty and `Q` has at most
`size(rotations, 2)` columns.
"""
function fused_order!(Q, rotations)
    numrot = size(rotations, 2)
    layers = size(rotations, 1)
    # No rotations at all: nothing to do (also avoids reading rotations[k, 1])
    isempty(rotations) && return Q
    if size(Q, 2) <= numrot
        throw(DimensionMismatch(
            "Q has $(size(Q, 2)) columns, but the rotations act on $(numrot + 1) columns"))
    end

    # Pairs of layers (k, k + 1); stopping at layers - 1 keeps rotations[k + 1, ...] valid
    @inbounds for k = 1 : 2 : layers - 1
        # First the left-over rotation of layer k at the left edge
        c, s = rotations[k, 1]
        rmul!(Q, Rotation2(c, s, 1))

        # Then the fused rotations: (k, n+1), (k, n+2), (k+1, n), (k+1, n+1)
        for n = 1 : 2 : numrot - 2
            c1, s1 = rotations[k, n + 1]
            c2, s2 = rotations[k, n + 2]
            c3, s3 = rotations[k + 1, n + 0]
            c4, s4 = rotations[k + 1, n + 1]
            G = Fused2x2(c1, s1, c2, s2, c3, s3, c4, s4, n)
            rmul!(Q, G)
        end

        # With an even rotation count the fused loop stops one step early. The rotation
        # numrot of layer k and the rotation numrot - 1 of layer k + 1 are then still
        # pending, and must be applied in this order.
        if iseven(numrot)
            c, s = rotations[k, numrot]
            rmul!(Q, Rotation2(c, s, numrot))
            c, s = rotations[k + 1, numrot - 1]
            rmul!(Q, Rotation2(c, s, numrot - 1))
        end

        # Then the last left-over rotation of layer k + 1 at the right edge
        c, s = rotations[k + 1, numrot]
        rmul!(Q, Rotation2(c, s, numrot))
    end

    # An unpaired final layer is applied in the trivial order
    if isodd(layers)
        @inbounds for n = 1 : numrot
            c, s = rotations[layers, n]
            rmul!(Q, Rotation2(c, s, n))
        end
    end

    return Q
end
25.443058307008616 33.9416449412521
24.035781911514 34.74023021949046
21.91412691626458 33.827249728948324
22.428363354241842 35.74526459698624
21.653893894835054 31.788585731766553
22.837777012894172 36.57162239511564
22.827950593092833 36.24360547209555
22.129723327969486 35.587929964564715
20.75834604779891 34.19481911087906
17.53682763059729 31.306708398498515
14.557958485864136 28.39682799515229
13.268273789557043 26.586767979803074
12.371269965403373 25.148035286474574
11.87160186149154 24.500843730657717
11.409941485568504 24.100749385746642
11.193609162623634 23.340339370714855
10.977945588068426 22.902733624125194
10.829642308188332 22.62420426666628
10.698704894235814 22.45534911736149
10.567489622366047 22.333505961490342
10.453453227191035 22.062975235473203
10.304598945317368 21.901561040237002
10.250387131331692 20.919456093284385
10.221043800605974 21.677911489203353
10.244324474117501 21.561385706572892
10.22107224529424 21.321814282119465
10.08931160681687 21.324135294927196
10.046588497815552 21.410853988265696
10.049439113087082 21.06104910371202
10.047482857832954 21.031550466874695
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment