Last active
April 26, 2020 04:36
-
-
Save cgarciae/3080787e77c0f129ec1fcbecf3074be5 to your computer and use it in GitHub Desktop.
Julia distance function + Python imports
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Base.Threads
using Distributions
using BenchmarkTools

# Select the Python interpreter PyCall should use; this has to happen before
# `using PyCall` below. `Sys.which` returns `nothing` when no `python` is on the
# PATH, which would otherwise be stored in ENV as the string "nothing".
let python = Sys.which("python")
    python === nothing && error("could not find a `python` executable on PATH")
    ENV["PYCALL_JL_RUNTIME_PYTHON"] = python
end
using PyCall

# Make the local `test` module and the project's virtualenv packages importable.
# NOTE(review): the site-packages path hard-codes Python 3.8 — confirm it matches
# the interpreter selected above.
py"""
import sys
sys.path.insert(0, ".")
sys.path.insert(0, ".venv/lib/python3.8/site-packages")
"""
test = pyimport("test")

# NumPy-style `newaxis`: indexing with `[CartesianIndex()]` inserts a singleton
# dimension, e.g. `v[:, None]` turns a length-n vector into an n×1 matrix.
# Declared `const` so functions that read it remain type-stable.
const None = [CartesianIndex()]
"""
    distances_jl(data1, data2)

Compute pairwise haversine (great-circle) distances between the points in
`data1` and the points in `data2` using fused broadcasting.

# Arguments
- `data1`: matrix with one point per row; column 1 is latitude and column 2 is
  longitude, both in degrees.
- `data2`: matrix of points in the same layout.

# Returns
A `size(data1, 1) × size(data2, 1)` matrix whose `(i, j)` entry is the distance
between row `i` of `data1` and row `j` of `data2` on a sphere of radius 6373
(presumably Earth's radius in kilometres). The element type follows the
floating-point type of the degree-to-radian converted inputs (e.g. `Float32` in,
`Float32` out).
"""
function distances_jl(data1, data2)
    rad1 = deg2rad.(data1)
    rad2 = deg2rad.(data2)

    # Points of `data1` as n1-element vectors, which broadcast as columns...
    lat1 = @view rad1[:, 1]
    lng1 = @view rad1[:, 2]
    # ...and points of `data2` as 1×n2 rows, so every broadcast below yields an
    # n1×n2 matrix. A plain vector would be aligned with the FIRST dimension
    # (Julia, unlike NumPy, aligns leading dimensions), pairing lat1[i] with
    # lat2[i] instead of lat2[j].
    lat2 = reshape(@view(rad2[:, 1]), 1, :)
    lng2 = reshape(@view(rad2[:, 2]), 1, :)

    # Haversine term a = sin²(Δφ/2) + cos φ₁ · cos φ₂ · sin²(Δλ/2), one fused pass.
    data = @. sin((lat1 - lat2) / 2)^2 +
              cos(lat1) * cos(lat2) * sin((lng1 - lng2) / 2)^2

    # Central angle 2·atan(√a, √(1 − a)) scaled by the radius, written in place.
    # `abs` presumably guards against tiny negative values from rounding.
    @. data = 2.0 * 6373.0 * atan(sqrt(abs(data)), sqrt(abs(1.0 - data)))
    return data
end
"""
    distances_threaded(data1, data2)

Multithreaded haversine distances between every row of `data1` and every row of
`data2`. Each row is a point with latitude in column 1 and longitude in column 2,
both in degrees.

Returns a `Matrix{Float64}` of size `size(data1, 1) × size(data2, 1)` whose
`(i, j)` entry is the distance between row `i` of `data1` and row `j` of `data2`
on a sphere of radius 6373 (presumably kilometres).
"""
function distances_threaded(data1, data2)
    src = deg2rad.(data1)
    dst = deg2rad.(data2)
    src_lat = view(src, :, 1)
    src_lng = view(src, :, 2)
    dst_lat = view(dst, :, 1)
    dst_lng = view(dst, :, 2)

    out = Matrix{Float64}(undef, length(src_lat), length(dst_lat))

    # Pass 1: haversine term, one output column per destination point. Each
    # thread writes only its own columns, so no synchronisation is needed.
    @threads for col in eachindex(dst_lat)
        lat_c = dst_lat[col]
        lng_c = dst_lng[col]
        for row in eachindex(src_lat)
            out[row, col] = sin((src_lat[row] - lat_c) / 2)^2 +
                            cos(src_lat[row]) * cos(lat_c) * sin((src_lng[row] - lng_c) / 2)^2
        end
    end

    # Pass 2: turn each haversine term into an arc length, element by element.
    @threads for k in eachindex(out)
        h = out[k]
        out[k] = 2.0 * 6373.0 * atan(sqrt(abs(h)), sqrt(abs(1.0 - h)))
    end

    return out
end
# Random benchmark points in degrees, one (lat, lng) pair per row, as Float32.
# Latitudes are kept within the valid ±90° range; longitudes span ±180°.
random_points(n) = Float32.(hcat(rand(Uniform(-90, 90), n), rand(Uniform(-180, 180), n)))

const a = random_points(5000)
const b = random_points(5000)

# Interpolate inputs with `$` so BenchmarkTools times the call itself rather
# than global-variable lookups; the Python attribute lookups are hoisted as well.
@btime distances_threaded($a, $b)
@btime distances_jl($a, $b)
@btime $(test.distances_np)($a, $b)
# NOTE(review): if `distances_jax` returns an asynchronously computed JAX array,
# this may time only dispatch, not the computation — confirm it blocks until ready.
@btime $(test.distances_jax)($a, $b)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment