
@theogf
Last active October 19, 2021 14:13
Benchmarking of Tullio and Distances.jl
using Tullio
using Distances
using LinearAlgebra
using BenchmarkTools
using CUDA, CUDAKernels, KernelAbstractions # Loaded so that Tullio can also generate GPU (CUDA) kernels
using KernelFunctions
using Test
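# Plain dot product, used as a (non-metric) pairwise function
struct DotProduct end
(::DotProduct)(x, y) = dot(x, y)
# Evaluate SqEuclidean through the expansion ‖x‖² - 2x⋅y + ‖y‖², so the cross term goes through `dot`
# (note: this defines a method on a type owned by Distances.jl)
(::SqEuclidean)(x, y) = sum(abs2, x) - 2 * dot(x, y) + sum(abs2, y)
# "Standard" implementations: broadcasting for vectors of vectors, matrix algebra for ColVecs.
# In every case K[i, j] = metric(x[i], y[j]): rows index x, columns index y.
std_pairwise(metric, x::AbstractVector, y::AbstractVector) = metric.(x, permutedims(y))
std_pairwise(::SqEuclidean, x::ColVecs, y::ColVecs) = sum(abs2, x.X, dims=1)' .+ sum(abs2, y.X, dims=1) .- 2 * x.X' * y.X
std_pairwise(::DotProduct, x::ColVecs, y::ColVecs) = x.X'y.X
Distances.pairwise(m::DotProduct, x, y) = std_pairwise(m, x, y) # DotProduct is not a metric, so route it to std_pairwise
# Tullio implementations: the index k appears only on the right-hand side, so it is summed over
tullio_pairwise(metric, x::AbstractVector, y::AbstractVector) = @tullio K[i, j] := metric(x[i], y[j])
tullio_pairwise(::SqEuclidean, x::ColVecs, y::ColVecs) = @tullio K[i, j] := x.X[k, i] ^ 2 - 2 * x.X[k, i] * y.X[k, j] + y.X[k, j] ^ 2
tullio_pairwise(::DotProduct, x::ColVecs, y::ColVecs) = @tullio K[i, j] := x.X[k, i] * y.X[k, j]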
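# Problem size: N points of dimension D
D = 20
N = 1000
# CPU inputs in two layouts: ColVecs (wrapping a D×N matrix) and a vector of vectors
Xmat = rand(D, N)
Xcol = ColVecs(Xmat)
Xvec = collect.(eachcol(Xmat))
# GPU inputs: gpuX is a fresh random CuArray, gpuXvec copies the CPU vectors to the GPU
gpuX = CUDA.rand(D, N)
gpuXcol = ColVecs(gpuX)
gpuXvec = cu.(Xvec)
Ymat = rand(D, N)
Ycol = ColVecs(Ymat)
Yvec = collect.(eachcol(Ymat))
gpuY = CUDA.rand(D, N)
gpuYcol = ColVecs(gpuY)
gpuYvec = cu.(Yvec)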
@testset "Test correct implementation" begin
    for metric in (SqEuclidean(), DotProduct())
        @testset "$(string(metric))" begin
            for (X, Y) in zip((Xcol, Xvec), (Ycol, Yvec))
                @test pairwise(metric, X, Y) ≈ std_pairwise(metric, X, Y)
                @test pairwise(metric, X, Y) ≈ tullio_pairwise(metric, X, Y)
            end
        end
    end
end
CUDA.allowscalar(false) # Disallow any scalar operations
@testset "GPU" begin
    for metric in (SqEuclidean(), DotProduct())
        @testset "$(string(metric))" begin
            for (X, Y) in zip((gpuXcol, ), (gpuYcol, ))
                @test_nowarn std_pairwise(metric, X, Y)
                @test_nowarn tullio_pairwise(metric, X, Y)
                # @test_nowarn pairwise(metric, X, Y)
            end
        end
    end
end
results_benchmark = Dict()
inputs = Dict(:col => (Xcol, Ycol), :vec => (Xvec, Yvec)) # look up inputs by layout instead of eval(Meta.parse(...))
for metric in [SqEuclidean(), DotProduct()]
    results_benchmark[metric] = Dict()
    for type in [:col, :vec]
        results_benchmark[metric][type] = Dict()
        X, Y = inputs[type]
        results_benchmark[metric][type][:b_dist] = @benchmark pairwise($metric, $X, $Y)
        results_benchmark[metric][type][:b_std] = @benchmark std_pairwise($metric, $X, $Y)
        results_benchmark[metric][type][:b_tullio] = @benchmark tullio_pairwise($metric, $X, $Y)
    end
end
# Pretty printing
for metric in [SqEuclidean(), DotProduct()]
    println("Testing metric $(metric)")
    for type in [:col, :vec]
        println("Testing X$type")
        for method in [:b_dist, :b_std, :b_tullio]
            println("Method $(method)")
            display(results_benchmark[metric][type][method]) # index by the current metric, not always SqEuclidean()
        end
        println()
    end
    println()
end
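For reference, here is a minimal sketch of the identity behind the ColVecs method of `std_pairwise(::SqEuclidean, ...)`, written without KernelFunctions, Distances or a GPU (only the `LinearAlgebra` standard library). For columns xᵢ of `A` and yⱼ of `B`, ‖xᵢ − yⱼ‖² = ‖xᵢ‖² − 2xᵢ⋅yⱼ + ‖yⱼ‖², so the squared norms of `A` must vary down the rows of the result and those of `B` across its columns. The helper names `gram_sqeuclidean` and `loop_sqeuclidean` are hypothetical; the loop version spells out the sum over `k` that the `@tullio` expression in `tullio_pairwise(::SqEuclidean, ...)` leaves implicit.

```julia
using LinearAlgebra

# Hypothetical helper: squared Euclidean distances between the columns of
# A (D×N) and B (D×M) via K[i, j] = ‖A[:, i]‖² - 2 A[:, i]⋅B[:, j] + ‖B[:, j]‖²
function gram_sqeuclidean(A::AbstractMatrix, B::AbstractMatrix)
    na = vec(sum(abs2, A; dims=1))  # ‖A[:, i]‖², length N
    nb = vec(sum(abs2, B; dims=1))  # ‖B[:, j]‖², length M
    # na is a column (varies with the row index i); nb' is a row (varies with j)
    return na .+ nb' .- 2 .* (A' * B)
end

# Hypothetical reference: the same contraction as explicit loops,
# summing (A[k, i] - B[k, j])^2 over the shared dimension k
function loop_sqeuclidean(A::AbstractMatrix, B::AbstractMatrix)
    size(A, 1) == size(B, 1) || throw(DimensionMismatch("A and B need the same number of rows"))
    T = promote_type(eltype(A), eltype(B))
    K = zeros(T, size(A, 2), size(B, 2))
    for j in axes(B, 2), i in axes(A, 2)
        acc = zero(T)
        for k in axes(A, 1)
            acc += (A[k, i] - B[k, j])^2
        end
        K[i, j] = acc
    end
    return K
end

A = rand(3, 5)
B = rand(3, 4)
K = gram_sqeuclidean(A, B)
size(K)                      # (5, 4): rows index A's columns, columns index B's
K ≈ loop_sqeuclidean(A, B)   # true, up to floating-point rounding
```

If the transposes are swapped (`sum(abs2, A, dims=1) .+ sum(abs2, B, dims=1)'`), the expression only runs when both inputs have the same number of columns, and even then it matches the loop version only in special cases such as `A == B`; comparing against distinct inputs, as the CPU `@testset` does, is what exposes the mistake.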
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
theogf commented Oct 19, 2021

Testing metric SqEuclidean(0.0)

Testing ColVecs

Method b_dist
BenchmarkTools.Trial: 148 samples with 1 evaluation.
 Range (min … max):  31.190 ms … 45.849 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     33.198 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   33.869 ms ±  2.599 ms  ┊ GC (mean ± σ):  0.59% ± 1.69%

  █ ▃ ▃▄ ▄▂ ▂ ▄ ▃                                              
  █▆███████▆█▆█▅█▄▇▄▄▃▄▁▄▅▃▄▅▃▃▁▃▃▄▁▁▁▃▄▁▃▃▁▁▁▃▁▁▁▁▃▃▁▄▁▁▁▁▁▃ ▃
  31.2 ms         Histogram: frequency by time        42.9 ms <

 Memory estimate: 7.63 MiB, allocs estimate: 2.
Method b_std
BenchmarkTools.Trial: 1499 samples with 1 evaluation.
 Range (min … max):  2.291 ms …   9.082 ms  ┊ GC (min … max):  0.00% … 26.23%
 Time  (median):     2.913 ms               ┊ GC (median):     0.00%
 Time  (mean ± σ):   3.301 ms ± 917.501 μs  ┊ GC (mean ± σ):  12.53% ± 16.70%

     ▂▂▅██▅▄▅▂▁                                                
  ▂▃▆███████████▇▅▄▅▃▄▄▃▃▃▃▂▂▂▂▂▂▁▁▂▂▃▃▄▅▄▄▄▅▅▄▄▃▃▃▄▃▂▃▂▃▂▂▂▂ ▄
  2.29 ms         Histogram: frequency by time        5.84 ms <

 Memory estimate: 15.43 MiB, allocs estimate: 8.
Method b_tullio
BenchmarkTools.Trial: 1088 samples with 1 evaluation.
 Range (min … max):  3.691 ms … 15.717 ms  ┊ GC (min … max): 0.00% … 18.27%
 Time  (median):     4.164 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   4.566 ms ±  1.110 ms  ┊ GC (mean ± σ):  4.41% ±  8.98%

  ▁▅▇██▇▇▆▅▂▃ ▁▁  ▁      ▂        ▂▁▁                        ▁
  ██████████████▇▆█▇▆▆▆▁▇██▇█▆▅█▆▆███▇█▇▆▆▇█▅▁▄▄▁▄▁▁▁▄▄▁▁▄▆▅ █
  3.69 ms      Histogram: log(frequency) by time     8.72 ms <

 Memory estimate: 7.63 MiB, allocs estimate: 27.

Testing Vec{Vec}

Method b_dist
BenchmarkTools.Trial: 213 samples with 1 evaluation.
 Range (min … max):  22.657 ms …  27.813 ms  ┊ GC (min … max): 0.00% … 10.85%
 Time  (median):     23.216 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   23.489 ms ± 845.840 μs  ┊ GC (mean ± σ):  0.85% ±  2.38%

   ▆▄█▆▆ ▃█▅ ▃▁                                                 
  ▇█████▇██████▆▄▃▄▁▃█▄▄▁▆▄▃▃▁▃▃▁▄▁▄▄▁▄▁▃▁▃▄▄▁▃▃▁▁▁▁▁▁▁▁▁▃▁▃▁▃ ▄
  22.7 ms         Histogram: frequency by time         26.5 ms <

 Memory estimate: 7.63 MiB, allocs estimate: 2.
Method b_std
BenchmarkTools.Trial: 207 samples with 1 evaluation.
 Range (min … max):  22.968 ms … 31.026 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     23.942 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   24.224 ms ±  1.030 ms  ┊ GC (mean ± σ):  0.78% ± 2.23%

     ▂  ▃▇▄▇█ ▇▇▅  ▁                                           
  ▆▅▇██▃██████████▇█▄▇▄▅▄▅▅▄▄▆▅▁▆▃▃▃▅▁▃▄▃▄▄▃▁▁▄▁▃▃▁▃▁▁▃▁▁▁▁▁▄ ▄
  23 ms           Histogram: frequency by time        27.4 ms <

 Memory estimate: 7.63 MiB, allocs estimate: 4.
Method b_tullio
BenchmarkTools.Trial: 529 samples with 1 evaluation.
 Range (min … max):  8.011 ms … 15.590 ms  ┊ GC (min … max): 0.00% … 15.71%
 Time  (median):     8.784 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   9.435 ms ±  1.511 ms  ┊ GC (mean ± σ):  2.18% ±  5.32%

    ▃█▆▆▅ ▁                                                   
  ▄▇█████▇█▇▅▅▃▃▃▄▄▃▂▄▃▂▃▃▃▃▃▃▂▃▃▃▃▂▃▃▁▃▁▃▃▂▅▅▃▂▁▁▁▁▁▁▁▁▂▂▁▂ ▃
  8.01 ms        Histogram: frequency by time        14.8 ms <

 Memory estimate: 7.63 MiB, allocs estimate: 30.
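To read the SqEuclidean(0.0) results above at a glance, the median times can be converted into speedups relative to `Distances.pairwise` (`b_dist`). The sketch below only hard-codes the medians printed in this comment, in milliseconds; the `medians` container is an illustrative assumption, not something the gist produces.

```julia
# Median times (ms) copied from the SqEuclidean(0.0) output above
medians = Dict(
    :col => (b_dist = 33.198, b_std = 2.913, b_tullio = 4.164),   # ColVecs
    :vec => (b_dist = 23.216, b_std = 23.942, b_tullio = 8.784),  # Vec{Vec}
)

# Speedup of each method over Distances.pairwise (b_dist); values > 1 mean faster
for layout in (:col, :vec)
    m = medians[layout]
    for method in (:b_std, :b_tullio)
        speedup = m.b_dist / getproperty(m, method)
        println(layout, "  ", method, "  ", round(speedup; digits=2), "x")
    end
end
```

With these numbers, the matrix-algebra `std_pairwise` is about 11.4x faster than `Distances.pairwise` for ColVecs and Tullio about 8.0x, while for Vec{Vec} the broadcast `std_pairwise` is on par (about 0.97x, within the spread of the timings) and Tullio about 2.6x faster.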

Testing metric DotProduct()

The output printed under this heading was a verbatim copy of the SqEuclidean(0.0) results above, consistent with the printing loop displaying `results_benchmark[SqEuclidean()][type][method]` for every metric; no separate DotProduct() timings were reported.
