Last active
April 26, 2020 14:46
-
-
Save cgarciae/7fcfc95709b1d94b27010c5e00db6690 to your computer and use it in GitHub Desktop.
python distance functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing as tp | |
from jax import numpy as jnp | |
import jax | |
import numpy as np | |
import time | |
@jax.jit
def _distances_jax(data1, data2, radius=6373.0):
    """Return pairwise great-circle (haversine) distances, JIT-compiled by JAX.

    Args:
        data1: array of shape ``(n, 2)``; column 0 holds latitudes and
            column 1 longitudes, both in degrees.
        data2: array of shape ``(m, 2)`` with the same column layout.
        radius: radius of the sphere; the result is in the same unit.
            Defaults to 6373, the value previously hard-coded (presumably
            Earth's radius in kilometres).

    Returns:
        A JAX array of shape ``(n, m)`` whose entry ``[i, j]`` is the
        distance between ``data1[i]`` and ``data2[j]``.
    """
    # Use ``jnp`` directly instead of rebinding ``np`` locally, which used
    # to shadow the module-level NumPy import.
    rad1 = jnp.deg2rad(data1)
    rad2 = jnp.deg2rad(data2)
    lat1, lng1 = rad1[:, 0], rad1[:, 1]
    lat2, lng2 = rad2[:, 0], rad2[:, 1]

    # Column vector minus row vector broadcasts to an (n, m) grid, so no
    # reshape is needed afterwards.
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    hav = (
        jnp.sin(diff_lat / 2) ** 2
        + jnp.cos(lat1[:, None]) * jnp.cos(lat2) * jnp.sin(diff_lng / 2) ** 2
    )
    # abs() presumably guards against tiny negative values caused by float
    # rounding, which would otherwise make sqrt return NaN.
    return 2 * radius * jnp.arctan2(jnp.sqrt(jnp.abs(hav)), jnp.sqrt(jnp.abs(1 - hav)))
def distances_jax(data1, data2):
    """Compute pairwise haversine distances with JAX, returned as NumPy.

    Args:
        data1: ``(n, 2)`` array of (latitude, longitude) pairs in degrees.
        data2: ``(m, 2)`` array with the same column layout.

    Returns:
        A ``numpy.ndarray`` of shape ``(n, m)`` produced by the jitted
        ``_distances_jax`` kernel.
    """
    device_result = _distances_jax(data1, data2)
    # Converting to a host NumPy array presumably also waits for JAX's
    # asynchronous dispatch to finish, so callers get a materialised result.
    host_result = np.asarray(device_result)
    return host_result
def distances_np(data1, data2, radius=6373.0):
    """Return pairwise great-circle (haversine) distances using NumPy.

    Args:
        data1: array of shape ``(n, 2)``; column 0 holds latitudes and
            column 1 longitudes, both in degrees.
        data2: array of shape ``(m, 2)`` with the same column layout.
        radius: radius of the sphere; the result is in the same unit.
            Defaults to 6373, the value previously hard-coded (presumably
            Earth's radius in kilometres).

    Returns:
        ``numpy.ndarray`` of shape ``(n, m)`` whose entry ``[i, j]`` is the
        distance between ``data1[i]`` and ``data2[j]``. Identical points
        give 0. The previous version added a spurious ``+ 1.0`` to every
        entry, which disagreed with the JAX implementation.
    """
    rad1 = np.deg2rad(data1)
    rad2 = np.deg2rad(data2)
    lat1, lng1 = rad1[:, 0], rad1[:, 1]
    lat2, lng2 = rad2[:, 0], rad2[:, 1]

    # Column vector minus row vector broadcasts to an (n, m) grid, so the
    # result already has the final shape and needs no reshape.
    diff_lat = lat1[:, None] - lat2
    diff_lng = lng1[:, None] - lng2
    hav = (
        np.sin(diff_lat / 2) ** 2
        + np.cos(lat1[:, None]) * np.cos(lat2) * np.sin(diff_lng / 2) ** 2
    )
    # abs() presumably guards against tiny negative values caused by float
    # rounding, which would otherwise make sqrt return NaN.
    return 2 * radius * np.arctan2(np.sqrt(np.abs(hav)), np.sqrt(np.abs(1 - hav)))
if __name__ == '__main__':
    # Two random (5000, 2) float32 point sets. Both columns are drawn from
    # [-100, 100), so some "latitudes" lie outside [-90, 90]; that is fine
    # for a timing comparison but not meaningful geographically.
    a = np.random.uniform(-100, 100, size=(5000, 2)).astype(np.float32)
    b = np.random.uniform(-100, 100, size=(5000, 2)).astype(np.float32)

    # perf_counter is monotonic and high-resolution, which suits interval
    # timing; time.time() follows the wall clock and can jump.
    t0 = time.perf_counter()
    d = distances_np(a, b)
    print("time np", time.perf_counter() - t0)

    # Warm-up, kept out of the timed section. The tiny array creation
    # initialises the JAX backend, and the first distances_jax call triggers
    # JIT compilation for these input shapes. The same inputs as the NumPy
    # run are reused so both implementations are timed on identical data.
    jnp.array([1])
    d = distances_jax(a, b)

    t0 = time.perf_counter()
    d = distances_jax(a, b)
    print("time jax", (time.perf_counter() - t0))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment