Created
September 25, 2018 14:53
-
-
Save haampie/e7b4c162b0632310716cf1f9da8f0639 to your computer and use it in GitHub Desktop.
session2.jl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| using LinearAlgebra | |
| using LinearAlgebra: givensAlgorithm | |
| using Test | |
| using BenchmarkTools | |
| import LinearAlgebra: rmul! | |
"""
    SmallRotation

Supertype of the rotation objects in this file (`Rotation2` and `Fused2x2`).
`rmul!(A, G::SmallRotation)` dispatches on it to apply a rotation to all rows of `A`.
"""
abstract type SmallRotation end
"""
    Rotation2(c, s, i)

A single plane rotation acting on the two adjacent columns `i` and `i + 1`.

# Fields
- `c`: coefficient multiplying the column that is being updated.
- `s`: coefficient multiplying the neighbouring column (used as `s'` for column `i`
  and as `-s` for column `i + 1`).
- `i`: index of the first of the two columns.
"""
struct Rotation2{Tc,Ts} <: SmallRotation
    c::Tc
    s::Ts
    i::Int
end
"""
    Fused2x2(c1, s1, c2, s2, c3, s3, c4, s4, i)

Four plane rotations fused into one object that touches the four adjacent columns
`i`, `i + 1`, `i + 2` and `i + 3`. Schematically:

    G3 G4
       G1 G2

The rotations are applied in the order `G1`, `G2`, `G3`, `G4`:

- `G1 = (c1, s1)` mixes columns `i + 1` and `i + 2`,
- `G2 = (c2, s2)` mixes columns `i + 2` and `i + 3`,
- `G3 = (c3, s3)` mixes columns `i`     and `i + 1`,
- `G4 = (c4, s4)` mixes columns `i + 1` and `i + 2`.
"""
struct Fused2x2{Tc,Ts} <: SmallRotation
    c1::Tc
    s1::Ts
    c2::Tc
    s2::Ts
    c3::Tc
    s3::Ts
    c4::Tc
    s4::Ts
    i::Int
end
"""
    rmul!(A::AbstractMatrix, G::SmallRotation)

Apply the rotation `G` to every row of `A` by forwarding to the three-argument
method with `axes(A, 1)` as the row range.
"""
@inline function rmul!(A::AbstractMatrix, G::SmallRotation)
    return rmul!(A, G, axes(A, 1))
end
"""
    rmul!(A::AbstractMatrix, G::Rotation2, range)

Apply the plane rotation `G` in place to columns `G.i` and `G.i + 1` of `A`, for the
rows listed in `range`, and return `A`.

For each row `j` the pair `(a₁, a₂) = (A[j, G.i], A[j, G.i + 1])` is replaced by
`(a₁ * c + a₂ * s', -a₁ * s + a₂ * c)`.

Indexing is done under `@inbounds`: the caller is responsible for `range` being valid
row indices and for `G.i + 1` being a valid column index of `A`.
"""
@inline function rmul!(A::AbstractMatrix, G::Rotation2, range)
    @inbounds for j = range
        # Read both entries before writing, since each output needs both inputs.
        a₁ = A[j, G.i + 0]
        a₂ = A[j, G.i + 1]
        a₁′ = muladd( a₁, G.c, a₂ * G.s')
        a₂′ = muladd(-a₁, G.s, a₂ * G.c )
        A[j, G.i + 0] = a₁′
        A[j, G.i + 1] = a₂′
    end
    # Return the mutated matrix, like the `Fused2x2` method and `LinearAlgebra.rmul!`;
    # previously this method returned `nothing`.
    return A
end
"""
    rmul!(A::AbstractMatrix, G::Fused2x2, range)

Apply the four fused rotations of `G` in place to columns `G.i` through `G.i + 3`
of `A`, for the rows listed in `range`, and return `A`.

Each row is loaded once, all four rotations are applied to the loaded values, and the
four results are written back. Indexing is done under `@inbounds`.
"""
@inline function rmul!(A::AbstractMatrix, G::Fused2x2, range)
    col = G.i
    @inbounds for row = range
        # Load the four entries of this row that the fused rotations touch.
        w = A[row, col + 0]
        x = A[row, col + 1]
        y = A[row, col + 2]
        z = A[row, col + 3]

        # Rotation 1: columns i + 1 and i + 2.
        x1 = muladd( x, G.c1, y * G.s1')
        y1 = muladd(-x, G.s1, y * G.c1 )

        # Rotation 2: columns i + 2 and i + 3.
        y2 = muladd( y1, G.c2, z * G.s2')
        z2 = muladd(-y1, G.s2, z * G.c2 )

        # Rotation 3: columns i and i + 1.
        w3 = muladd( w, G.c3, x1 * G.s3')
        x3 = muladd(-w, G.s3, x1 * G.c3 )

        # Rotation 4: columns i + 1 and i + 2.
        x4 = muladd( x3, G.c4, y2 * G.s4')
        y4 = muladd(-x3, G.s4, y2 * G.c4 )

        # Store the final value of each column.
        A[row, col + 0] = w3
        A[row, col + 1] = x4
        A[row, col + 2] = y4
        A[row, col + 3] = z2
    end
    return A
end
"""
    generate_rotations(cols, ks)

Return a `ks × (cols - 1)` matrix of random `(c, s)` pairs, one per layer and per
pair of neighbouring columns. Each pair is the first two outputs of
`givensAlgorithm` applied to two `rand()` draws.
"""
function generate_rotations(cols, ks)
    rotations = Matrix{Tuple{Float64,Float64}}(undef, ks, cols - 1)
    # Linear indexing walks the matrix column by column (layer index fastest).
    for idx in eachindex(rotations)
        c, s, _ = givensAlgorithm(rand(), rand())
        rotations[idx] = (c, s)
    end
    return rotations
end
"""
    example(cols, ks, rows = 1000)

Sanity check: build `ks` layers of random rotations for a random `rows × cols` matrix,
apply them once with `trivial_order!` and once with `fused_order!`, and test that both
results agree and differ from the input.
"""
function example(cols, ks, rows = 1000)
    rotations = generate_rotations(cols, ks)

    # Two independent copies of the same random matrix, one per method.
    Q = rand(rows, cols)
    Q1, Q2 = copy(Q), copy(Q)

    trivial_order!(Q1, rotations)
    fused_order!(Q2, rotations)

    # Both orderings should give the same matrix (compared with `≈`, not `==`).
    @test Q1 ≈ Q2
    # Make sure the rotations actually modified the matrix.
    @test Q1 != Q
end
"""
    bench(ns = 100:100:2000, k = 16)

Benchmark `trivial_order!` against `fused_order!` on `n × n` matrices with `k` layers
of rotations for every `n` in `ns`. Logs the throughput of both methods per size and
returns two vectors of GFLOP/s values `(trivial, fused)`.
"""
function bench(ns = 100:100:2000, k = 16)
    trivial_rates = Float64[]
    fused_rates = Float64[]

    for n = ns
        rotations = generate_rotations(n, k)

        # 6 flops per rotation per row, (n - 1) rotations per layer, k layers, n rows.
        total_flops = 6k * n * (n - 1)

        t_trivial = @belapsed trivial_order!(Q, $rotations) setup = (Q = rand($n, $n))
        t_fused = @belapsed fused_order!(Q, $rotations) setup = (Q = rand($n, $n))

        # These two names double as the keys shown in the log record below.
        gflops_per_sec_a = total_flops / t_trivial / 1e9
        gflops_per_sec_b = total_flops / t_fused / 1e9

        push!(trivial_rates, gflops_per_sec_a)
        push!(fused_rates, gflops_per_sec_b)

        @info "Size ($n × $n) with $k layers" gflops_per_sec_a gflops_per_sec_b
    end

    return trivial_rates, fused_rates
end
"""
    trivial_order!(Q, rotations)

Apply all rotations to `Q` in the straightforward order: layer by layer (rows of
`rotations`), and within each layer from the first column pair to the last.
`rotations[k, n]` is a `(c, s)` pair acting on columns `n` and `n + 1`. Returns `Q`.
"""
function trivial_order!(Q, rotations)
    @inbounds for layer in axes(rotations, 1), col in axes(rotations, 2)
        cs, sn = rotations[layer, col]
        rmul!(Q, Rotation2(cs, sn, col))
    end
    return Q
end
"""
    fused_order!(Q, rotations)

Apply all rotations to `Q`, processing two layers at a time in "fused" groups of
2 by 2 rotations, and return `Q`. The rotations are applied in an order that respects
every dependency of `trivial_order!`, so both functions compute the same product.

`rotations[k, n]` is a `(c, s)` pair acting on columns `n` and `n + 1` in layer `k`.

Any number of layers and any number of rotations per layer is accepted:
- rotations that do not fit in a complete 2 by 2 group are applied one by one;
- with an odd number of layers the last layer is applied in the trivial order.

# Throws
- `DimensionMismatch` if `Q` has fewer than `size(rotations, 2) + 1` columns.
"""
function fused_order!(Q, rotations)
    layers = size(rotations, 1)
    numrot = size(rotations, 2)

    # Nothing to do; this also avoids reading `rotations[k, 1]` from an empty matrix.
    (layers == 0 || numrot == 0) && return Q

    # The kernels index columns 1 through numrot + 1 without bounds checks.
    if size(Q, 2) < numrot + 1
        throw(DimensionMismatch(
            "Q has $(size(Q, 2)) columns, but $numrot rotations per layer require " *
            "at least $(numrot + 1)",
        ))
    end

    # Layers 1:paired are handled in pairs (k, k + 1); an odd last layer is left over.
    paired = layers - isodd(layers)

    # First rotation of layer k + 1 that the fused groups do not cover:
    # `numrot` when numrot is odd, `numrot - 1` when it is even.
    tail = numrot - iseven(numrot)

    @inbounds for k = 1 : 2 : paired
        # First the left-over rotation of layer k.
        c, s = rotations[k, 1]
        rmul!(Q, Rotation2(c, s, 1))

        # Then the fused rotations: two from layer k, two from layer k + 1.
        for n = 1 : 2 : numrot - 2
            c1, s1 = rotations[k, n + 1]
            c2, s2 = rotations[k, n + 2]
            c3, s3 = rotations[k + 1, n + 0]
            c4, s4 = rotations[k + 1, n + 1]
            G = Fused2x2(c1, s1, c2, s2, c3, s3, c4, s4, n)
            rmul!(Q, G)
        end

        # With an even number of rotations, layer k still misses its last rotation.
        if tail < numrot
            c, s = rotations[k, numrot]
            rmul!(Q, Rotation2(c, s, numrot))
        end

        # Finally the left-over rotation(s) of layer k + 1.
        for n = tail : numrot
            c, s = rotations[k + 1, n]
            rmul!(Q, Rotation2(c, s, n))
        end
    end

    # An unpaired final layer is applied in the trivial order.
    if isodd(layers)
        @inbounds for n = 1 : numrot
            c, s = rotations[layers, n]
            rmul!(Q, Rotation2(c, s, n))
        end
    end

    return Q
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| 25.443058307008616 33.9416449412521 | |
| 24.035781911514 34.74023021949046 | |
| 21.91412691626458 33.827249728948324 | |
| 22.428363354241842 35.74526459698624 | |
| 21.653893894835054 31.788585731766553 | |
| 22.837777012894172 36.57162239511564 | |
| 22.827950593092833 36.24360547209555 | |
| 22.129723327969486 35.587929964564715 | |
| 20.75834604779891 34.19481911087906 | |
| 17.53682763059729 31.306708398498515 | |
| 14.557958485864136 28.39682799515229 | |
| 13.268273789557043 26.586767979803074 | |
| 12.371269965403373 25.148035286474574 | |
| 11.87160186149154 24.500843730657717 | |
| 11.409941485568504 24.100749385746642 | |
| 11.193609162623634 23.340339370714855 | |
| 10.977945588068426 22.902733624125194 | |
| 10.829642308188332 22.62420426666628 | |
| 10.698704894235814 22.45534911736149 | |
| 10.567489622366047 22.333505961490342 | |
| 10.453453227191035 22.062975235473203 | |
| 10.304598945317368 21.901561040237002 | |
| 10.250387131331692 20.919456093284385 | |
| 10.221043800605974 21.677911489203353 | |
| 10.244324474117501 21.561385706572892 | |
| 10.22107224529424 21.321814282119465 | |
| 10.08931160681687 21.324135294927196 | |
| 10.046588497815552 21.410853988265696 | |
| 10.049439113087082 21.06104910371202 | |
| 10.047482857832954 21.031550466874695 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment