Created
November 5, 2021 10:10
-
-
Save theogf/50426b2e991bba8868f6728d1325518b to your computer and use it in GitHub Desktop.
Test Tullio with Kernel Functions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Tullio | |
using Distances | |
using LinearAlgebra | |
using BenchmarkTools | |
using CUDA, CUDAKernels, KernelAbstractions | |
using Functors | |
using KernelFunctions | |
using Test | |
using Functors | |
"""
    DotProduct()

Callable pseudo-metric whose value on two points is their inner product `dot(x, y)`.
It is not a proper distance, so it gets its own `pairwise` routing further down.
"""
struct DotProduct end

function (::DotProduct)(a, b)
    return dot(a, b)
end
# Evaluate `SqEuclidean` on two points through the expansion ‖x‖² − 2⟨x, y⟩ + ‖y‖²,
# the same decomposition the matrix-based implementations use.
# NOTE(review): this is type piracy. Neither the call operator nor `SqEuclidean` is
# owned by this script, and the method replaces Distances' own evaluation for every
# caller in the session. Acceptable in a throwaway benchmark; do not copy into a package.
function (::SqEuclidean)(x, y)
    return sum(abs2, x) - 2 * dot(x, y) + sum(abs2, y)
end
"""
    std_pairwise(metric, x, y)

Reference (non-Tullio) implementation of the pairwise matrix `K` with
`K[i, j] = metric(x[i], y[j])`.

- Generic vectors of points: broadcast `metric` over `x` and the transposed `y`.
- `ColVecs` inputs: dense linear algebra on the wrapped matrices. Each column of
  `x.X` / `y.X` is one point, so the result is `length(x) × length(y)`.
"""
std_pairwise(metric, x::AbstractVector, y::AbstractVector) = metric.(x, permutedims(y))

# K[i, j] = ‖xᵢ‖² + ‖yⱼ‖² − 2⟨xᵢ, yⱼ⟩.
# The norms of x must vary down the rows (a column vector) and those of y along the
# columns (a row vector) so that entry [i, j] pairs xᵢ with yⱼ, matching x.X' * y.X.
function std_pairwise(::SqEuclidean, x::ColVecs, y::ColVecs)
    xnorms = vec(sum(abs2, x.X; dims=1))   # one entry per point of x
    ynorms = sum(abs2, y.X; dims=1)        # 1 × (number of points of y)
    return xnorms .+ ynorms .- 2 .* (x.X' * y.X)
end

std_pairwise(::DotProduct, x::ColVecs, y::ColVecs) = x.X' * y.X
# `DotProduct` is not a proper metric, so route Distances' `pairwise` for it to the
# reference implementation `std_pairwise`.
function Distances.pairwise(metric::DotProduct, x, y)
    return std_pairwise(metric, x, y)
end
"""
    tullio_pairwise(metric, x, y)

Tullio-based implementation of the pairwise matrix with entries `metric(x[i], y[j])`.

For `ColVecs` inputs the metric is written as an index expression over the feature
index `k`. Because `k` does not appear on the left-hand side, Tullio sums over it.
"""
function tullio_pairwise(metric, x::AbstractVector, y::AbstractVector)
    @tullio out[i, j] := metric(x[i], y[j])
    return out
end

function tullio_pairwise(::SqEuclidean, x::ColVecs, y::ColVecs)
    @tullio out[i, j] := x.X[k, i]^2 - 2 * x.X[k, i] * y.X[k, j] + y.X[k, j]^2
    return out
end

function tullio_pairwise(::DotProduct, x::ColVecs, y::ColVecs)
    @tullio out[i, j] := x.X[k, i] * y.X[k, j]
    return out
end
# Problem size: N points of dimension D per input set.
D = 20
N = 1000

# Each input set (X, then Y) consists of:
#   - a host D × N matrix, wrapped as `ColVecs` and split into a vector of columns;
#   - a GPU `ColVecs` over a separate `CUDA.rand` matrix;
#   - GPU copies of the host column vectors.
# The random draws happen in the order Xmat, gpuX, Ymat, gpuY.
Xmat = rand(D, N)
Xcol = ColVecs(Xmat)
Xvec = [collect(col) for col in eachcol(Xmat)]
gpuX = CUDA.rand(D, N)
gpuXcol = ColVecs(gpuX)
gpuXvec = map(cu, Xvec)

Ymat = rand(D, N)
Ycol = ColVecs(Ymat)
Yvec = [collect(col) for col in eachcol(Ymat)]
gpuY = CUDA.rand(D, N)
gpuYcol = ColVecs(gpuY)
gpuYvec = map(cu, Yvec)
# Check both custom implementations against `pairwise` on the host, for both
# input representations (`ColVecs` and vector-of-vectors).
@testset "Test correct implementation" begin
    host_pairs = ((Xcol, Ycol), (Xvec, Yvec))
    for metric in (SqEuclidean(), DotProduct())
        @testset "$(metric)" begin
            for (A, B) in host_pairs
                @test pairwise(metric, A, B) ≈ std_pairwise(metric, A, B)
                @test pairwise(metric, A, B) ≈ tullio_pairwise(metric, A, B)
            end
        end
    end
end
# Disallow scalar indexing of GPU arrays, so an accidental element-by-element
# fallback raises an error instead of passing silently.
CUDA.allowscalar(false)
@testset "GPU" begin
    for metric in (SqEuclidean(), DotProduct())
        @testset "$(string(metric))" begin
            for (X, Y) in ((gpuXcol, gpuYcol),)
                @test_nowarn std_pairwise(metric, X, Y)
                @test_nowarn tullio_pairwise(metric, X, Y)
                # Besides running cleanly, each implementation must produce on the GPU
                # the same values it produces on host copies of the same Float32 data.
                Xhost, Yhost = ColVecs(Array(X.X)), ColVecs(Array(Y.X))
                @test Array(std_pairwise(metric, X, Y)) ≈
                      std_pairwise(metric, Xhost, Yhost)
                @test Array(tullio_pairwise(metric, X, Y)) ≈
                      tullio_pairwise(metric, Xhost, Yhost)
                # NOTE(review): `pairwise` is deliberately not exercised on GPU inputs;
                # presumably it does not support them yet — confirm before enabling.
            end
        end
    end
end
# Benchmark every implementation on both input representations.
# Layout: results_benchmark[metric][type][method] holds a BenchmarkTools trial, with
# type ∈ (:col, :vec) and method ∈ (:b_dist, :b_std, :b_tullio).
results_benchmark = Dict()
# Look inputs up by representation instead of `eval`-ing a constructed variable name.
benchmark_inputs = Dict(:col => (Xcol, Ycol), :vec => (Xvec, Yvec))
for metric in [SqEuclidean(), DotProduct()]
    results_benchmark[metric] = Dict()
    for type in [:col, :vec]
        trials = results_benchmark[metric][type] = Dict()
        X, Y = benchmark_inputs[type]
        # `$` interpolation keeps global-variable access out of the timed expression.
        trials[:b_dist] = @benchmark pairwise($metric, $X, $Y)
        trials[:b_std] = @benchmark std_pairwise($metric, $X, $Y)
        trials[:b_tullio] = @benchmark tullio_pairwise($metric, $X, $Y)
    end
end
# Pretty-print the stored benchmark trials, grouped by metric and input representation.
for metric in [SqEuclidean(), DotProduct()]
    println("Testing metric $(metric)")
    for type in [:col, :vec]
        println("Testing X$type")
        for method in [:b_dist, :b_std, :b_tullio]
            println("Method $(method)")
            # Index by the loop's `metric` so each section shows its own trials.
            display(results_benchmark[metric][type][method])
        end
        println()
    end
    println()
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment