|
|
|
import jax.numpy as np |
|
from jax import jit, vmap, pmap |
|
from timeit import timeit |
|
from functools import partial |
|
|
|
|
|
def bkern(kernel, func=vmap):
    """Lift a pairwise kernel ``kernel(u, v)`` to a function of two batches.

    The returned function maps ``(x, y)`` to a matrix ``K`` with
    ``K[i, j] == kernel(x[i], y[j])``, i.e. shape ``(len(x), len(y))`` -- the
    same layout as ``se(x, y)``.

    Args:
        kernel: function of one element of ``x`` and one element of ``y``.
        func: batching transform applied at both levels (e.g. ``vmap`` or
            ``pmap``).

    Returns:
        A function ``(x, y) -> K``.
    """
    def gram(x, y):
        # Outer transform maps over the elements of x (rows of K),
        # inner transform over the elements of y (columns of K).
        return func(lambda xi: func(lambda yj: kernel(xi, yj))(y))(x)

    return gram
|
|
|
def veckern(kernel):
    """Batch ``kernel`` over both of its arguments using ``jax.vmap`` (via ``bkern``)."""
    return bkern(kernel, func=vmap)
|
|
|
def parkern(kernel):
    """Batch ``kernel`` over both of its arguments using ``jax.pmap`` (via ``bkern``).

    NOTE(review): this nests one pmap inside another; presumably that needs
    enough devices to cover both mapped axes -- confirm before relying on it.
    """
    return bkern(kernel, func=pmap)
|
|
|
|
|
|
|
def se(a, b):
    """Pairwise squared Euclidean distances between the rows of ``a`` and ``b``.

    Uses the expansion ``|u - v|^2 = |u|^2 + |v|^2 - 2 u.v`` so the whole
    matrix comes out of a single matrix product.

    Args:
        a: 2-D array of shape (n, d).
        b: 2-D array of shape (m, d).

    Returns:
        Array of shape (n, m) with ``out[i, j] == sum((a[i] - b[j]) ** 2)``.
    """
    a_sq = np.einsum('ij,ij->i', a, a)  # squared norm of each row of a
    b_sq = np.einsum('ij,ij->i', b, b)  # squared norm of each row of b
    d2 = a_sq[:, np.newaxis] + b_sq - 2 * a @ b.T
    # The expansion suffers catastrophic cancellation for nearly equal rows
    # and can round to small negative values; true distances are >= 0.
    return np.maximum(d2, 0)
|
|
|
def _sq_dist(u, v):
    """Squared Euclidean distance between two vectors."""
    diff = u - v
    return np.sum(diff ** 2)


# Pairwise squared distances obtained by lifting ``_sq_dist`` with ``veckern``.
vse = veckern(_sq_dist)
|
|
|
|
|
def sm(x: np.ndarray, ax=0) -> np.ndarray:
    """Softmax of ``x`` along axis ``ax``.

    Args:
        x: input array.
        ax: axis along which the values are normalized to sum to 1.

    Returns:
        Array of the same shape as ``x``.
    """
    # Subtracting the per-slice maximum leaves the result unchanged
    # mathematically but keeps exp() from overflowing to inf for large inputs.
    e = np.exp(x - np.max(x, axis=ax, keepdims=True))
    # keepdims keeps the reduced axis so the division broadcasts along ``ax``
    # for any axis, not just the leading one.
    return e / np.sum(e, axis=ax, keepdims=True)
|
|
|
def _row_softmax(row):
    """Softmax of a single vector, shifted by its max so exp() cannot overflow."""
    e = np.exp(row - np.max(row))
    return e / np.sum(e)


# Softmax applied independently to each slice along axis 0 (each row of a
# 2-D input), via vmap's default in_axes=0.
vsm = vmap(_row_softmax)
|
|
|
|
|
def _bench(label, fn, *args, number=100):
    """Print ``label`` followed by the mean wall-clock time of ``fn(*args)``.

    JAX dispatches work asynchronously, so every timed call blocks on its
    result with ``block_until_ready()``; without it only the dispatch
    overhead would be measured. One untimed warm-up call keeps tracing /
    compilation (relevant once the functions are ``jit``-wrapped) out of the
    measurement.
    """
    def run():
        return fn(*args).block_until_ready()

    run()  # warm-up, excluded from timing
    per_call = timeit(run, number=number) / number
    print(f"{label}{per_call * 1e6:.1f} us per loop (mean of {number} runs)")


def _run_benchmarks(x):
    """Time the vmap-based and hand-vectorized kernels on input ``x``.

    The kernels are looked up as module globals at call time, so calling this
    again after the ``jit`` rebinding below times the compiled versions.
    NOTE: ``vsm`` normalizes each row (vmap over axis 0) while ``sm(x)``
    defaults to ``ax=0`` (columns), so the two softmax timings are not
    strictly the same computation.
    """
    print("Squared euclidean distance:")
    _bench("vmap ", vse, x, x)
    _bench("vectorized manualy ", se, x, x)
    print("Softmax:")
    _bench("vmap ", vsm, x)
    _bench("vectorized manualy ", sm, x)


if "a" not in globals():
    # NOTE(review): ``a`` is not defined anywhere in this file (presumably it
    # came from an earlier notebook cell). Fall back to a random 2-D sample
    # so the script runs standalone; adjust the shape to the real workload.
    from jax import random

    a = random.normal(random.PRNGKey(0), (500, 32))

_run_benchmarks(a)

# Rebind the module-level names to their jit-compiled versions and re-time.
vse, se, vsm, sm = [jit(f) for f in (vse, se, vsm, sm)]

_run_benchmarks(a)