-
-
Save wolfv/70246cd5888cbad9051b179b57f91f2a to your computer and use it in GitHub Desktop.
logsumexp.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <numpy/arrayobject.h>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xcontainer.hpp"
#include "xtensor/xbroadcast.hpp"
//#include "xtensor/xbuilder.hpp"
#include "xtensor/xview.hpp"
#include "xtensor/xeval.hpp"
#include "xtensor/xstridedview.hpp"
#include "xtensor-python/pyarray.hpp"
#include "xtensor-python/pytensor.hpp"
#include "xtensor/xio.hpp"
#include <algorithm>    // std::find
#include <cstddef>      // std::size_t
#include <type_traits>  // std::decay_t
#include <vector>
namespace py = pybind11;
template<class E1> | |
auto logsumexp1 (E1 const& e1) { | |
using value_type = typename std::decay_t<E1>::value_type; | |
auto max = xt::amax (e1)(); | |
return std::move (max + xt::log (xt::sum (xt::exp (e1-max)))); | |
} | |
template<class E1, class X> | |
auto logsumexp2 (const E1& e1, X const& axes) { | |
using value_type = typename std::decay_t<E1>::value_type; | |
auto max = xt::eval(xt::amax(e1, axes)); | |
auto sv = xt::slice_vector(max); | |
for (int i = 0; i < e1.dimension(); i++) | |
{ | |
if (std::find (axes.begin(), axes.end(), i) != axes.end()) | |
sv.push_back(xt::newaxis()); | |
else | |
sv.push_back(e1.shape()[i]); | |
} | |
auto max2 = xt::dynamic_view(max, sv); | |
return (xt::pyarray<value_type>(max2 + xt::log(xt::sum(xt::exp(e1 - max2), axes)))); | |
} | |
template<class value_type> | |
auto normalize (xt::pytensor<value_type,2> const& e1) { | |
auto axis = std::vector<size_t>{ e1.dimension() - 1 }; | |
auto ls = logsumexp2(e1, axis); | |
auto sv = xt::slice_vector(ls); | |
sv.push_back(xt::all()); | |
sv.push_back(xt::newaxis()); | |
auto ls2 = xt::dynamic_view(ls, sv); | |
return xt::pyarray<value_type>((e1 - ls2)); | |
} | |
// Python module entry point (legacy PYBIND11_PLUGIN macro) exposing the
// overloaded `logsumexp` and the `normalize` functions.
PYBIND11_PLUGIN (logsumexp) {
    // Initialise NumPy's C API, which backs the xt::pyarray / xt::pytensor
    // arguments below; report ImportError to Python on failure.
    const bool numpy_ok = _import_array() >= 0;
    if (!numpy_ok) {
        PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import");
        return nullptr;
    }

    py::module mod("logsumexp", "pybind11 example plugin");

    // logsumexp(x): reduce over every element.
    mod.def("logsumexp", [](const xt::pyarray<double>& x) {
        return xt::pyarray<double>(xt::eval(logsumexp1(x)));
    });

    // logsumexp(x, axes): reduce over the listed axes only.
    mod.def("logsumexp", [](const xt::pyarray<double>& x, const std::vector<size_t>& axes) {
        return logsumexp2 (x, axes);
    });

    // normalize(x): log-space normalisation of a 2-D array.
    mod.def("normalize", [](const xt::pytensor<double, 2>& x) {
        return normalize(x);
    });

    return mod.ptr();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from xtensor_test.logsumexp import logsumexp, normalize | |
def logsumexp_py1 (u):
    """Numerically stable log(sum(exp(u))) over every element of ``u``.

    The largest element is factored out before exponentiating to avoid
    overflow, then added back.
    """
    shift = np.max (u)
    total = np.sum (np.exp (u - shift))
    return shift + np.log (total)
def logsumexp_py2 (u, axes):
    """Numerically stable log(sum(exp(u))) reduced over ``axes``.

    Parameters
    ----------
    u : array_like
        Input values.
    axes : int or tuple of int
        Axes to reduce over; negative values count from the end.

    Returns
    -------
    ndarray (or scalar)
        Result with the reduced axes removed, i.e. the same shape as
        ``np.sum(u, axes)``.
    """
    u = np.asarray (u)
    # keepdims=True keeps the reduced axes as length-1 dimensions so that
    # `m` broadcasts correctly against `u`.
    m = np.max (u, axis=axes, keepdims=True)
    # Add back the max with the reduced axes squeezed out so the result has
    # the sum's shape; adding the keepdims `m` would broadcast to a wrong,
    # larger shape.
    return np.squeeze (m, axis=axes) + np.log (np.sum (np.exp (u - m), axis=axes))
def logsumexp_py (u, axes=None):
    """Python reference logsumexp.

    With ``axes=None`` reduce over every element (``logsumexp_py1``);
    otherwise reduce over ``axes`` (``logsumexp_py2``).
    """
    # `is None` rather than `== None`: when axes is an ndarray, `==` compares
    # elementwise and the `if` would raise on an ambiguous truth value.
    if axes is None:
        return logsumexp_py1 (u)
    return logsumexp_py2 (u, axes)
u = np.ones (4)
v = logsumexp(u)
print (v)
print (logsumexp_py (u, (0,)))
print (logsumexp (u, (0,)))
print ('logsumexp (np.ones ((2,4))):', logsumexp (np.ones ((2,4))))
print (logsumexp_py (np.ones ((2,4))))
print ('logsumexp_py (np.ones ((2,4)), (1,))):', logsumexp_py (np.ones ((2,4)), (1,)))
print ('logsumexp (np.ones ((2,4)), (1,)):', logsumexp (np.ones ((2,4)), (1,)))
#print (np.ones ((2,4)))
print ('norm:', normalize (np.ones ((2,4))))
# SciPy's reference implementation lives in scipy.special; the old
# scipy.misc.logsumexp was deprecated in SciPy 1.0 and removed in 1.3.
from scipy.special import logsumexp as logsumexp2
print (logsumexp2 (u))
def normalize2 (u):
    """Reference log-space row normalisation of a 2-D array using SciPy.

    Prints the per-row logsumexp and the input, then returns ``u`` minus the
    per-row logsumexp broadcast across columns.
    """
    row_lse = logsumexp2 (u, axis=-1)
    print ('ls:', row_lse)
    print ('u:', u)
    return u - row_lse[:, np.newaxis]
print ('norm2:', normalize2 (np.ones (shape=(2, 4))))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment