Skip to content

Instantly share code, notes, and snippets.

@tkf
Last active December 4, 2018 04:40
Show Gist options
  • Save tkf/926802a335ec491784192d8566a42cf1 to your computer and use it in GitHub Desktop.
Save tkf/926802a335ec491784192d8566a42cf1 to your computer and use it in GitHub Desktop.
module SparseDots
import SparseArrays
using Random
using SIMD
"""
    SpVect{IndVec,ValVec}

Minimal sparse-vector container: a logical length `n`, the positions of the
stored entries (`nzind`) and the stored values themselves (`nzval`).

The two vector fields are type parameters so that any array type can be used
for the index and value storage.
"""
struct SpVect{IndVec,ValVec}
    n::Int          # logical length of the vector
    nzind::IndVec   # indices of the stored entries
    nzval::ValVec   # values of the stored entries, parallel to `nzind`
end
# Convert an `SpVect` into a standard `SparseArrays.SparseVector`, passing the
# index and value storage through `Vector(...)` first.
function SparseArrays.SparseVector(sv::SpVect)
    inds = Vector(sv.nzind)
    vals = Vector(sv.nzval)
    return SparseArrays.SparseVector(sv.n, inds, vals)
end
"""
    rand_spvect(rng, n, p=0.1; N=4, align=true)

Build a random `SpVect` of logical length `n`.

Stored positions are drawn with `randsubseq(rng, 1:n, p)`; stored values are
standard-normal samples divided by the square root of the number of stored
entries.

With `align == false` the freshly generated vectors are wrapped as they are.
With `align == true` both vectors are copied into buffers obtained from
`valloc(eltype, N, length)` before being wrapped.
"""
function rand_spvect(rng::AbstractRNG, n, p=0.1; N=4, align=true)
    inds = randsubseq(rng, 1:n, p)
    nstored = length(inds)
    vals = randn(rng, nstored) ./ √nstored

    if align
        # Zero the whole parent buffer first, then copy the payload into the
        # view returned by `valloc` (presumably so that any padding around the
        # view holds defined values).
        aligned_inds = valloc(eltype(inds), N, nstored)
        fill!(aligned_inds.parent, 0)
        aligned_inds .= inds

        aligned_vals = valloc(eltype(vals), N, nstored)
        fill!(aligned_vals.parent, 0)
        aligned_vals .= vals

        return SpVect(n, aligned_inds, aligned_vals)
    else
        return SpVect(n, inds, vals)
    end
end
"""
    dot_simple(sv::SpVect, ys::AbstractVector, _)

Scalar reference implementation of the sparse-dense dot product: accumulates
`sv.nzval[k] * ys[sv.nzind[k]]` over all stored entries with `muladd`.

The third positional argument is accepted and ignored so that this function
can be called with the same arguments as the SIMD variants.
"""
@inline function dot_simple(sv::SpVect, ys::AbstractVector,
                            ::Any,  # to ignore Val(align)
                            )
    @assert sv.n == length(ys)
    vals = sv.nzval
    inds = sv.nzind
    total = zero(promote_type(eltype(vals), eltype(ys)))
    @inbounds for k in eachindex(vals)
        total = muladd(vals[k], ys[inds[k]], total)
    end
    return total
end
"""
    dot_simd(sv::SpVect, ys::AbstractVector, ::Val{align}=Val(false), ::Val{N}=Val(4))

Sparse-dense dot product using explicit `N`-wide SIMD vectors and array
indexing.

Each iteration of the main loop loads `N` indices and `N` values from `sv`,
gathers the matching entries of `ys`, and folds them into a vector accumulator
with `muladd`. The lanes are then summed and the remaining
`length(sv.nzval) % N` entries are handled by a scalar loop.

`align` is forwarded as `Val{align}` to `vload`/`vgather`.

Like `dot_simple` and `dot_ptr`, this asserts `sv.n == length(ys)` before any
`@inbounds` access is made.
"""
@inline function dot_simd(sv::SpVect, ys::AbstractVector,
                          ::Val{align} = Val(false),
                          ::Val{N} = Val(4),
                          ) where {N, align}
    # Same guard as the sibling kernels; without it a mismatched `ys` would be
    # read through unchecked gathers below.
    @assert sv.n == length(ys)
    Ti = eltype(sv.nzind)
    Tv = eltype(sv.nzval)
    T = promote_type(eltype(sv.nzval), eltype(ys))
    isempty(sv.nzval) && return zero(T)

    nz_size = length(sv.nzval)
    nomask = Vec(ntuple(_ -> true, N))  # all lanes active
    vacc = zero(Vec{N, T})

    # Full vectors: j = 1, 1+N, ... while a complete group of N entries remains.
    @inbounds for j in 1:N:(nz_size - N + 1)
        idx = vload(Vec{N, Ti}, sv.nzind, j, Val{align})
        bs = vgather(ys, idx, nomask, Val{align})
        as = vload(Vec{N, Tv}, sv.nzval, j, Val{align})
        vacc = muladd(as, bs, vacc)
    end

    # Horizontal reduction, then the scalar tail for the last nz_size % N entries.
    acc = sum(vacc)
    @inbounds for j in (nz_size - nz_size % N + 1):nz_size
        acc = muladd(sv.nzval[j], ys[sv.nzind[j]], acc)
    end
    return acc
end
"""
    dot_ptr(sv::SpVect, ys::AbstractVector, ::Val{align}=Val(false), ::Val{N}=Val(4))

Sparse-dense dot product using explicit `N`-wide SIMD vectors and raw pointer
arithmetic instead of array indexing.

The main loop advances one pointer through `sv.nzind` and one through
`sv.nzval`, `N` elements at a time, gathers the matching entries of `ys`
through computed addresses, and folds the products into a vector accumulator.
The lanes are then summed and the remaining `length(sv.nzval) % N` entries
are handled by a scalar, index-based loop.

`align` is forwarded as `Val{align}` to `vload`/`vgather`.

The arrays behind the raw pointers are kept alive with `GC.@preserve` for the
whole pointer-based section.
"""
@inline function dot_ptr(sv::SpVect, ys::AbstractVector,
                         ::Val{align} = Val(false),
                         ::Val{N} = Val(4),
                         ) where {N, align}
    @assert sv.n == length(ys)
    Ti = eltype(sv.nzind)
    Tv = eltype(sv.nzval)
    Ty = eltype(ys)
    T = promote_type(eltype(sv.nzval), eltype(ys))
    isempty(sv.nzval) && return zero(T)

    # `GC.@preserve` only accepts plain variable names, so bind the fields.
    nzind = sv.nzind
    nzval = sv.nzval
    nomask = Vec(ntuple(_ -> true, N))  # all lanes active
    vacc = zero(Vec{N, T})

    # A `Ptr` does not root the object it was taken from; keep all three
    # arrays alive while their addresses are being dereferenced.
    GC.@preserve nzind nzval ys begin
        iptr_end = pointer(@inbounds @view nzind[end])
        # Last address from which a full group of N indices can still be loaded.
        iptr_simd_end = iptr_end - sizeof(Ti) * (N - 1)
        iptr = pointer(nzind)
        vptr = pointer(nzval)
        # Shifted back by one element so that a 1-based index `i` maps to
        # `yptr_ofs + sizeof(Ty) * i`.
        yptr_ofs = pointer(ys) - sizeof(Ty)
        @inbounds while iptr <= iptr_simd_end
            as = vload(Vec{N, Tv}, vptr, Val{align})
            idx = vload(Vec{N, Ti}, iptr, Val{align})
            bs = vgather(
                Vec{N, Ty}, yptr_ofs + sizeof(Ty) * idx,
                nomask, Val{align})
            vacc = muladd(as, bs, vacc)
            vptr += sizeof(Tv) * N
            iptr += sizeof(Ti) * N
        end
    end

    # Horizontal reduction, then the scalar tail for the last nz_size % N entries.
    nz_size = length(sv.nzval)
    acc = sum(vacc)
    @inbounds for j in (nz_size - nz_size % N + 1):nz_size
        acc = muladd(sv.nzval[j], ys[sv.nzind[j]], acc)
    end
    return acc
end
end # module
using Test
using Random
using LinearAlgebra
using SparseArrays
using Statistics
using SIMD
using BenchmarkTools
# Short script-level names for the (unexported) kernels and the generator.
(dot_simple, dot_simd, dot_ptr, rand_spvect) = (
    SparseDots.dot_simple,
    SparseDots.dot_simd,
    SparseDots.dot_ptr,
    SparseDots.rand_spvect,
)
"""
    makedata(avgdeg; p=0.1, N=4, align=true, rng=Random.GLOBAL_RNG)

Create one benchmark/test case and return it as the named tuple `(sv, ys)`.

The vector length is `ceil(Int, avgdeg / p)`. `sv` comes from `rand_spvect`
with that length and density `p`, and `ys` is a dense vector of the same
length filled with standard-normal samples. When `align` is true, `ys` is
allocated with `valloc` and filled in place; otherwise it is a plain `randn`
vector.

`rng` may be an `AbstractRNG`, or any value accepted by `MersenneTwister` as
a seed.
"""
function makedata(avgdeg; p = 0.1, N = 4, align = true,
                  rng = Random.GLOBAL_RNG)
    generator = if rng isa AbstractRNG
        rng
    else
        MersenneTwister(rng)
    end
    len = ceil(Int, avgdeg / p)
    sv = rand_spvect(generator, len, p; N = N, align = align)
    ys = if align
        buffer = valloc(eltype(sv.nzval), N, len)
        randn!(generator, buffer)
        buffer
    else
        randn(generator, len)
    end
    return (sv = sv, ys = ys)
end
# Every kernel must agree with `LinearAlgebra.dot` on the equivalent
# `SparseVector`, for several seeds and for both alignment settings.
@testset begin
    @testset for seed in 1:10, align in [false, true]
        (sv, ys) = makedata(5; rng = (seed * (align + 1)), align = align)
        reference = dot(SparseVector(sv), ys)
        # Evaluate all kernels before asserting anything.
        got_simple = dot_simple(sv, ys, Val(align))
        got_simd = dot_simd(sv, ys, Val(align))
        got_ptr = dot_ptr(sv, ys, Val(align))
        @test got_simple ≈ reference
        @test got_simd ≈ reference
        @test got_ptr ≈ reference
    end
end
# Build one benchmark per (kernel, alignment flag, average number of stored
# entries), then tune and run the whole suite.
suite = BenchmarkGroup()
sizelist = 2 .^ (4:2:14)
alignments = [false, true]

# NOTE(review): `setup` calls `makedata` with its default `align`, so the
# `align` key only changes the `Val` passed to the kernel, not the data.
for avgdeg in sizelist, dotfun in [dot_simple, dot_simd, dot_ptr], align in alignments
    suite[nameof(dotfun), align, avgdeg] = @benchmarkable begin
        sv, ys = data
        $dotfun(sv, ys, $(Val(align)))
    end setup=(data = makedata($avgdeg))
end

tune!(suite, verbose = true)
println()

results = run(suite, verbose = true)
display(results)
println()
# Collect, for each `n` in `sizelist`, the `.time` field of
# `f(results[dotfun, align, n])`. `f` defaults to `median`.
function times(results, sizelist, dotfun, align, f=median)
    time_at(n) = f(results[dotfun, align, n]).time
    return [time_at(n) for n in sizelist]
end
# Print the timings of each SIMD kernel relative to `dot_simple`, using the
# first alignment setting.
for (heading, kernel) in (("dot_simd / dot_simple", :dot_simd),
                          ("dot_ptr / dot_simple", :dot_ptr))
    println(heading)
    println(times(results, sizelist, kernel, alignments[1]) ./
            times(results, sizelist, :dot_simple, alignments[1]))
end
using Plots
# Plot the timing returned by `times` against `sizelist` on log-log axes, one
# series per (function, alignment) pair. Returns the plot object.
function plot_times(results, sizelist,
                    alignments = alignments,
                    functions = [:dot_simple, :dot_simd, :dot_ptr])
    fig = plot()
    for dotfun in functions, align in alignments
        suffix = align ? " aligned" : ""
        series = times(results, sizelist, dotfun, align)
        plot!(fig, sizelist, series;
              marker = :o,
              label = string(dotfun, suffix))
    end
    plot!(fig;
          ylabel = "Time [ns]",
          xlabel = "Average number non-zero elements",
          xscale = :log10,
          yscale = :log10,
          legend = :topleft)
    return fig
end
# Plot, for every function after the first in `functions`, its timing divided
# by the timing of `functions[1]` at the same alignment. A black horizontal
# line marks ratio 1. Returns the plot object.
function plot_ratio(results, sizelist,
                    alignments = alignments,
                    functions = [:dot_simple, :dot_simd, :dot_ptr])
    fig = plot()
    reference = functions[1]
    # Colour indices start right after `length(alignments)`, i.e. after the
    # indices the reference function's series would take in `plot_times`.
    color_index = length(alignments)
    for dotfun in functions[2:end], align in alignments
        color_index += 1
        baseline = times(results, sizelist, reference, align)
        ratio = times(results, sizelist, dotfun, align) ./ baseline
        suffix = align ? " aligned" : ""
        plot!(fig, sizelist, ratio;
              color = color_index,
              marker = :o,
              label = string("$dotfun/$(functions[1])", suffix))
    end
    plot!(fig, [1]; seriestype = :hline, color = :black, label = "")
    plot!(fig;
          ylabel = "Ratio to baseline",
          xlabel = "Average number non-zero elements",
          xscale = :log10,
          legend = :top)
    return fig
end
# Stack the timing plot above the ratio plot in a 2x1 layout; extra keyword
# arguments are forwarded to the outer `plot` call.
function plot_all(results, sizelist; kwargs...)
    upper = plot_times(results, sizelist)
    lower = plot_ratio(results, sizelist)
    return plot(upper, lower; layout = (2, 1), kwargs...)
end
# When this file is not the program being run, only build the figure;
# when it is, render with PyPlot and save the figure to disk.
if abspath(PROGRAM_FILE) != @__FILE__
    plt = plot_all(results, sizelist)
else
    pyplot()  # some lines vanish with GR
    plt = plot_all(results, sizelist, size = (600, 500))
    path = "benchmark.png"
    @info "Saving plot to $path"
    savefig(plt, path)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment