Last active
December 4, 2018 04:40
-
-
Save tkf/926802a335ec491784192d8566a42cf1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module SparseDots

import SparseArrays
using Random
using SIMD

"""
    SpVect(n, nzind, nzval)

Minimal sparse vector of length `n`: entry `nzval[k]` is stored at position
`nzind[k]`. The container types are left generic so that the index/value
storage can be `valloc`-ed views (see `rand_spvect`).
"""
struct SpVect{TVi,TVv}
    n::Int
    nzind::TVi
    nzval::TVv
end

# Convert to a standard `SparseVector` (copying the stored data), e.g. to
# compare results against `LinearAlgebra.dot`.
SparseArrays.SparseVector(sv::SpVect) =
    SparseArrays.SparseVector(sv.n, Vector(sv.nzind), Vector(sv.nzval))

"""
    rand_spvect(rng, n, p=0.1; N=4, align=true) -> SpVect

Random sparse vector of length `n`. Each position is stored independently
with probability `p` (via `randsubseq`); the stored values are `randn`
samples scaled by `1/√nnz`.

When `align` is true, the index and value arrays are copied into
`valloc`-ed storage for SIMD width `N`, with the whole backing buffer
(including padding) zero-filled first.
"""
function rand_spvect(rng::AbstractRNG, n, p=0.1; N=4, align=true)
    nzind = randsubseq(rng, 1:n, p)
    nz = length(nzind)
    nzval = randn(rng, nz) ./ √nz
    if !align
        return SpVect(n, nzind, nzval)
    end
    return SpVect(n, _aligned_copy(nzind, N), _aligned_copy(nzval, N))
end

# Copy `xs` into a `valloc`-ed buffer for SIMD width `N`; the parent buffer is
# zeroed first so its padding holds defined values.
function _aligned_copy(xs::AbstractVector, N)
    ys = valloc(eltype(xs), N, length(xs))
    fill!(parent(ys), 0)
    ys .= xs
    return ys
end

"""
    dot_simple(sv::SpVect, ys::AbstractVector, _)

Reference scalar implementation of `dot(sv, ys)`. The third argument is
ignored; it exists only so every kernel can be called as
`f(sv, ys, Val(align))`.
"""
@inline function dot_simple(sv::SpVect, ys::AbstractVector,
                            ::Any,  # ignored; accepts Val(align)
                            )
    @assert sv.n == length(ys)
    T = promote_type(eltype(sv.nzval), eltype(ys))
    acc = zero(T)
    # `eachindex` with both arrays also checks that they have matching axes,
    # which makes the `@inbounds` below safe.
    @inbounds for j in eachindex(sv.nzind, sv.nzval)
        acc = muladd(sv.nzval[j], ys[sv.nzind[j]], acc)
    end
    return acc
end

"""
    dot_simd(sv::SpVect, ys::AbstractVector, Val(align)=Val(false), Val(N)=Val(4))

`dot(sv, ys)` using SIMD.jl. Stored entries are processed `N` at a time:
`vload` reads indices and values, and `vgather` reads the matching entries of
`ys`. The remaining `length(sv.nzval) % N` entries are handled by a scalar
loop.

`align` is forwarded to the `vload`s of `sv.nzind`/`sv.nzval` only. It should
be `true` only when those arrays come from `valloc` with the same `N` (as
`rand_spvect(...; align=true)` produces).
"""
@inline function dot_simd(sv::SpVect, ys::AbstractVector,
                          ::Val{align} = Val(false),
                          ::Val{N} = Val(4),
                          ) where {N, align}
    @assert sv.n == length(ys)
    Ti = eltype(sv.nzind)
    Tv = eltype(sv.nzval)
    T = promote_type(Tv, eltype(ys))
    isempty(sv.nzval) && return zero(T)
    nz_size = length(sv.nzval)
    nomask = Vec(ntuple(_ -> true, N))
    vacc = zero(Vec{N, T})
    # Full N-wide chunks start at 1, 1+N, ... up to the last complete chunk.
    @inbounds for j in 1:N:(nz_size - N + 1)
        idx = vload(Vec{N, Ti}, sv.nzind, j, Val{align})
        # Gathered elements live at arbitrary positions of `ys`, so they are
        # only element-aligned; vector alignment must not be claimed here.
        bs = vgather(ys, idx, nomask, Val{false})
        as = vload(Vec{N, Tv}, sv.nzval, j, Val{align})
        vacc = muladd(as, bs, vacc)
    end
    acc = sum(vacc)
    # Scalar tail: the last `nz_size % N` stored entries.
    @inbounds for j in (nz_size - nz_size % N + 1):nz_size
        acc = muladd(sv.nzval[j], ys[sv.nzind[j]], acc)
    end
    return acc
end

"""
    dot_ptr(sv::SpVect, ys::AbstractVector, Val(align)=Val(false), Val(N)=Val(4))

Same computation as `dot_simd`, but the SIMD loop walks raw pointers into
`sv.nzind`/`sv.nzval` and gathers from `ys` through computed pointers. The
scalar tail is identical to `dot_simd`'s.
"""
@inline function dot_ptr(sv::SpVect, ys::AbstractVector,
                         ::Val{align} = Val(false),
                         ::Val{N} = Val(4),
                         ) where {N, align}
    @assert sv.n == length(ys)
    nzind = sv.nzind
    nzval = sv.nzval
    Ti = eltype(nzind)
    Tv = eltype(nzval)
    Ty = eltype(ys)
    T = promote_type(Tv, Ty)
    isempty(nzval) && return zero(T)
    nz_size = length(nzval)
    nomask = Vec(ntuple(_ -> true, N))
    vacc = zero(Vec{N, T})
    # The raw pointers below are only valid while these arrays stay rooted.
    GC.@preserve nzind nzval ys begin
        iptr = pointer(nzind)
        vptr = pointer(nzval)
        # Last start address from which a full N-wide load stays in bounds.
        iptr_simd_end = pointer(nzind, lastindex(nzind)) - sizeof(Ti) * (N - 1)
        # Indices are one-based: the address of ys[i] is yptr_ofs + sizeof(Ty) * i.
        yptr_ofs = pointer(ys) - sizeof(Ty)
        @inbounds while iptr <= iptr_simd_end
            as = vload(Vec{N, Tv}, vptr, Val{align})
            idx = vload(Vec{N, Ti}, iptr, Val{align})
            # Gathered elements are only element-aligned (arbitrary indices).
            bs = vgather(
                Vec{N, Ty}, yptr_ofs + sizeof(Ty) * idx,
                nomask, Val{false})
            vacc = muladd(as, bs, vacc)
            vptr += sizeof(Tv) * N
            iptr += sizeof(Ti) * N
        end
    end
    acc = sum(vacc)
    # Scalar tail: the last `nz_size % N` stored entries.
    @inbounds for j in (nz_size - nz_size % N + 1):nz_size
        acc = muladd(nzval[j], ys[nzind[j]], acc)
    end
    return acc
end

end # module
using Test | |
using Random | |
using LinearAlgebra | |
using SparseArrays | |
using Statistics | |
using SIMD | |
using BenchmarkTools | |
# Short top-level aliases for the kernels and the data generator. They are
# `const` so calls from top-level script code are not dispatched through
# untyped (non-const) globals.
const dot_simple = SparseDots.dot_simple
const dot_simd = SparseDots.dot_simd
const dot_ptr = SparseDots.dot_ptr
const rand_spvect = SparseDots.rand_spvect
"""
    makedata(avgdeg; p=0.1, N=4, align=true, rng=Random.GLOBAL_RNG)

Build a random sparse vector `sv` of length `ceil(Int, avgdeg / p)` and
density `p`, so it has about `avgdeg` stored entries in expectation. Also
build a dense `randn` vector `ys` of the same length.

`rng` may be an `AbstractRNG` or a seed passed to `MersenneTwister`. With
`align=true`, both operands use `valloc`-ed storage for SIMD width `N`.

Returns the named tuple `(sv = ..., ys = ...)`.
"""
function makedata(avgdeg; p = 0.1, N = 4, align = true,
                  rng = Random.GLOBAL_RNG)
    gen = rng isa AbstractRNG ? rng : MersenneTwister(rng)
    len = ceil(Int, avgdeg / p)
    sv = rand_spvect(gen, len, p; N = N, align = align)
    ys = if align
        randn!(gen, valloc(eltype(sv.nzval), N, len))
    else
        randn(gen, len)
    end
    return (sv = sv, ys = ys)
end
# Every kernel must agree with SparseArrays' `dot` on random inputs, for both
# plain and `valloc`-aligned data.
@testset begin
    @testset for seed in 1:10, align in [false, true]
        sv, ys = makedata(5; rng = seed * (align + 1), align = align)
        expected = dot(SparseVector(sv), ys)
        flag = Val(align)
        got = (dot_simple(sv, ys, flag),
               dot_simd(sv, ys, flag),
               dot_ptr(sv, ys, flag))
        for z in got
            @test z ≈ expected
        end
    end
end
# Benchmark every kernel × alignment flag × problem size. Keys are
# (kernel name, align, avgdeg); fresh data is built in `setup`.
# NOTE(review): `setup` calls `makedata` with its default `align = true`, so
# only the `Val(align)` hint passed to the kernel varies, not the data layout.
suite = BenchmarkGroup()
sizelist = 2 .^ (4:2:14)
alignments = [false, true]
for avgdeg in sizelist,
        dotfun in [dot_simple, dot_simd, dot_ptr],
        align in alignments
    suite[nameof(dotfun), align, avgdeg] = @benchmarkable begin
        sv, ys = data
        $dotfun(sv, ys, $(Val(align)))
    end setup=(data = makedata($avgdeg))
end
tune!(suite, verbose = true)
println()
results = run(suite, verbose = true)
display(results)
println()
"""
    times(results, sizelist, dotfun, align, f=median)

Return the estimated run times (`f(trial).time`, in nanoseconds for
BenchmarkTools trials) of kernel `dotfun` (a `Symbol`) with alignment flag
`align`, one entry per size in `sizelist`.
"""
function times(results, sizelist, dotfun, align, f=median)
    return map(n -> f(results[dotfun, align, n]).time, sizelist)
end
# Print speed-up ratios of the SIMD kernels relative to the scalar baseline,
# for the first entry of `alignments` only.
let baseline = times(results, sizelist, :dot_simple, alignments[1])
    for (label, kernel) in (("dot_simd / dot_simple", :dot_simd),
                            ("dot_ptr / dot_simple", :dot_ptr))
        println(label)
        println(times(results, sizelist, kernel, alignments[1]) ./ baseline)
    end
end
using Plots | |
"""
    plot_times(results, sizelist, alignments=alignments, functions=[...])

Log-log plot of median run time versus problem size, with one series per
(kernel, alignment flag) pair.
"""
function plot_times(results, sizelist,
                    alignments = alignments,
                    functions = [:dot_simple, :dot_simd, :dot_ptr])
    fig = plot()
    for kernel in functions
        for aligned in alignments
            suffix = aligned ? " aligned" : ""
            plot!(fig, sizelist, times(results, sizelist, kernel, aligned);
                  marker = :o,
                  label = string(kernel, suffix))
        end
    end
    plot!(fig;
          ylabel = "Time [ns]",
          xlabel = "Average number non-zero elements",
          xscale = :log10,
          yscale = :log10,
          legend = :topleft)
    return fig
end
"""
    plot_ratio(results, sizelist, alignments=alignments, functions=[...])

Plot each kernel's run time divided by that of `functions[1]` (the baseline)
at the same alignment flag. A horizontal line marks ratio 1.
"""
function plot_ratio(results, sizelist,
                    alignments = alignments,
                    functions = [:dot_simple, :dot_simd, :dot_ptr])
    fig = plot()
    base = functions[1]
    # Series colours start after the baseline's slots, presumably so they
    # line up with the matching series in `plot_times`.
    colour_idx = length(alignments)
    for kernel in functions[2:end], aligned in alignments
        reference = times(results, sizelist, base, aligned)
        colour_idx += 1
        plot!(fig, sizelist,
              times(results, sizelist, kernel, aligned) ./ reference;
              color = colour_idx,
              marker = :o,
              label = string("$kernel/$base", aligned ? " aligned" : ""))
    end
    plot!(fig, [1]; seriestype = :hline, color = :black, label = "")
    plot!(fig;
          ylabel = "Ratio to baseline",
          xlabel = "Average number non-zero elements",
          xscale = :log10,
          legend = :top)
    return fig
end
"""
    plot_all(results, sizelist; kwargs...)

Stack `plot_times` above `plot_ratio` in a 2×1 layout. Extra keyword
arguments are forwarded to `plot`.
"""
function plot_all(results, sizelist; kwargs...)
    upper = plot_times(results, sizelist)
    lower = plot_ratio(results, sizelist)
    return plot(upper, lower; layout = (2, 1), kwargs...)
end
# When executed as a script, render with PyPlot and save the figure to disk.
# Otherwise (e.g. when `include`d), just build the figure into `plt`.
if abspath(PROGRAM_FILE) != @__FILE__
    plt = plot_all(results, sizelist)
else
    pyplot()  # some lines vanish with the GR backend
    plt = plot_all(results, sizelist, size = (600, 500))
    path = "benchmark.png"
    @info "Saving plot to $path"
    savefig(plt, path)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment