Last active
October 19, 2021 14:13
-
-
Save theogf/7ed2bec68917283c02ce01dd14382ef6 to your computer and use it in GitHub Desktop.
Benchmarking Tullio.jl against Distances.jl for pairwise (SqEuclidean and dot-product) matrices
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Tullio | |
using Distances | |
using LinearAlgebra | |
using BenchmarkTools | |
using CUDA, CUDAKernels, KernelAbstractions # Seems to be needed by Tullio | |
using KernelFunctions | |
using Test | |
"""
    DotProduct()

Callable object returning the inner product `dot(x, y)` of its two arguments.
It is not a true distance; it serves as a second benchmark case next to
`SqEuclidean`.
"""
struct DotProduct
end

function (d::DotProduct)(x, y)
    return dot(x, y)
end
# NOTE(review): `SqEuclidean` is owned by Distances.jl, so adding a call method here
# is type piracy and likely overwrites the package's own `(::SqEuclidean)(x, y)`.
# Every caller of `SqEuclidean()(x, y)` in this session gets this formula instead.
# It uses whole-array reductions (`sum`, `dot`) rather than an elementwise loop,
# presumably so it also runs on GPU arrays without scalar indexing — confirm.
#
# ‖x - y‖² = ‖x‖² - 2⟨x, y⟩ + ‖y‖². When x ≈ y the subtraction can cancel to a
# tiny negative number, so clamp at zero: a squared distance is never negative.
function (::SqEuclidean)(x, y)
    d = sum(abs2, x) - 2 * dot(x, y) + sum(abs2, y)
    return max(zero(d), d)
end
# Reference pairwise evaluation: entry (i, j) is `metric(x[i], y[j])`.
# `x` acts as a column and `y` is laid out as a 1×n row, so broadcasting
# produces the full length(x)×length(y) matrix.
function std_pairwise(metric, x::AbstractVector, y::AbstractVector)
    yrow = permutedims(y)
    return broadcast(metric, x, yrow)
end
# Squared Euclidean distances between the columns of `x.X` (points x_i) and of
# `y.X` (points y_j), via K[i, j] = ‖x_i‖² + ‖y_j‖² - 2⟨x_i, y_j⟩.
# The x-norms must run down the rows (Nx×1) and the y-norms along the columns
# (1×Ny) to match the Nx×Ny Gram matrix; this also supports Nx ≠ Ny.
# Negative values from floating-point cancellation are clamped to zero.
function std_pairwise(::SqEuclidean, x::ColVecs, y::ColVecs)
    X, Y = x.X, y.X
    xnorms = sum(abs2, X; dims=1)  # 1×Nx, entry i is ‖x_i‖²
    ynorms = sum(abs2, Y; dims=1)  # 1×Ny, entry j is ‖y_j‖²
    G = X' * Y                     # Nx×Ny Gram matrix ⟨x_i, y_j⟩
    return max.(zero(eltype(G)), xnorms' .+ ynorms .- 2 .* G)
end
# Dot-product Gram matrix for column-stored inputs: entry (i, j) is ⟨x_i, y_j⟩,
# obtained as one matrix product.
function std_pairwise(::DotProduct, x::ColVecs, y::ColVecs)
    gram = x.X' * y.X
    return gram
end
# `DotProduct` is not a proper metric in the Distances.jl sense, so `pairwise`
# for it is routed to the plain reference implementation `std_pairwise`.
function Distances.pairwise(m::DotProduct, x, y)
    return std_pairwise(m, x, y)
end
# Generic Tullio version: builds K[i, j] = metric(x[i], y[j]) over all index
# pairs; `:=` makes Tullio allocate the output array.
function tullio_pairwise(metric, x::AbstractVector, y::AbstractVector)
    @tullio K[i, j] := metric(x[i], y[j])
    return K
end
# Squared Euclidean distances between columns via Tullio.
# K[i, j] = Σ_k (A[k, i] - B[k, j])². The index `k` appears only on the right, so
# Tullio sums over it. Squaring the difference (rather than expanding
# x² - 2xy + y²) gives the same sum with fewer operations, avoids cancellation
# when points are close, and can never produce a negative entry.
function tullio_pairwise(::SqEuclidean, x::ColVecs, y::ColVecs)
    A, B = x.X, y.X
    @tullio K[i, j] := (A[k, i] - B[k, j])^2
    return K
end
# Dot-product Gram matrix via Tullio: K[i, j] = Σ_k A[k, i] * B[k, j]
# (`k` is summed because it does not appear on the left-hand side).
function tullio_pairwise(::DotProduct, x::ColVecs, y::ColVecs)
    A, B = x.X, y.X
    @tullio K[i, j] := A[k, i] * B[k, j]
    return K
end
# Problem size: N points of dimension D, one point per matrix column.
D = 20
N = 1000

# CPU inputs in three layouts: raw matrix, `ColVecs` wrapper, and a vector of
# column vectors (all views of the same random data).
Xmat = rand(D, N)
Ymat = rand(D, N)
Xcol, Ycol = ColVecs(Xmat), ColVecs(Ymat)
Xvec, Yvec = collect.(eachcol(Xmat)), collect.(eachcol(Ymat))

# GPU inputs. gpuX/gpuY are fresh draws from CUDA.rand (not copies of Xmat/Ymat),
# while gpuXvec/gpuYvec are device copies of the CPU column vectors.
gpuX = CUDA.rand(D, N)
gpuY = CUDA.rand(D, N)
gpuXcol, gpuYcol = ColVecs(gpuX), ColVecs(gpuY)
gpuXvec, gpuYvec = cu.(Xvec), cu.(Yvec)
# Check both hand-written implementations against Distances.pairwise on the CPU,
# for the ColVecs layout and for the vector-of-vectors layout.
@testset "Test correct implementation" begin
    for m in (SqEuclidean(), DotProduct())
        @testset "$(string(m))" begin
            for (A, B) in ((Xcol, Ycol), (Xvec, Yvec))
                @test pairwise(m, A, B) ≈ std_pairwise(m, A, B)
                @test pairwise(m, A, B) ≈ tullio_pairwise(m, A, B)
            end
        end
    end
end
CUDA.allowscalar(false) # Disallow any scalar operations
# GPU checks for the ColVecs implementations.
# `@test_nowarn` only asserts that nothing is logged, so each GPU result is also
# compared with the same implementation run on a CPU copy of the same data.
@testset "GPU" begin
    for metric in (SqEuclidean(), DotProduct())
        @testset "$(string(metric))" begin
            for (X, Y) in zip((gpuXcol,), (gpuYcol,))
                # Host copies of the device data, used as the reference inputs.
                Xc = ColVecs(Array(X.X))
                Yc = ColVecs(Array(Y.X))
                @test_nowarn std_pairwise(metric, X, Y)
                @test_nowarn tullio_pairwise(metric, X, Y)
                @test Array(std_pairwise(metric, X, Y)) ≈ std_pairwise(metric, Xc, Yc)
                @test Array(tullio_pairwise(metric, X, Y)) ≈ tullio_pairwise(metric, Xc, Yc)
                # Distances.pairwise is not exercised on GPU inputs here:
                # @test_nowarn pairwise(metric, X, Y)
            end
        end
    end
end
# Benchmark every implementation for each metric and input layout.
# Results are stored as results_benchmark[metric][layout][method], where method is
# :b_dist (Distances.pairwise), :b_std (std_pairwise) or :b_tullio (tullio_pairwise).
# Inputs come from a Dict lookup rather than `eval(Meta.parse("X$type"))`: no
# string parsing or global eval, and an unknown layout fails with a KeyError.
benchmark_inputs = Dict(:col => (Xcol, Ycol), :vec => (Xvec, Yvec))
results_benchmark = Dict()
for metric in [SqEuclidean(), DotProduct()]
    results_benchmark[metric] = Dict()
    for layout in [:col, :vec]
        trials = Dict()
        results_benchmark[metric][layout] = trials
        X, Y = benchmark_inputs[layout]
        # `$` interpolation keeps global-variable access out of the timed expression.
        trials[:b_dist] = @benchmark pairwise($metric, $X, $Y)
        trials[:b_std] = @benchmark std_pairwise($metric, $X, $Y)
        trials[:b_tullio] = @benchmark tullio_pairwise($metric, $X, $Y)
    end
end
# Pretty printing: show the stored benchmark trial for every
# metric / input layout / method combination.
for metric in [SqEuclidean(), DotProduct()]
    println("Testing metric $(metric)")
    for type in [:col, :vec]
        println("Testing X$type")
        for method in [:b_dist, :b_std, :b_tullio]
            println("Method $(method)")
            # Index with the loop's `metric` so each metric's own results are shown.
            display(results_benchmark[metric][type][method])
        end
        println()
    end
    println()
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Testing metric SqEuclidean(0.0)
Testing ColVecs
Testing Vec{Vec}
Testing metric DotProduct()
Testing ColVecs
Testing Vec{Vec}