Last active
April 17, 2017 12:49
-
-
Save nbecker/7f13da1a108e956fdcea7915b29085f2 to your computer and use it in GitHub Desktop.
logsumexp.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <numpy/arrayobject.h>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xcontainer.hpp"
#include "xtensor/xbroadcast.hpp"
//#include "xtensor/xbuilder.hpp"
#include "xtensor/xview.hpp"
#include "xtensor/xeval.hpp"
#include "xtensor/xstridedview.hpp"
#include "xtensor-python/pyarray.hpp"
#include "xtensor-python/pytensor.hpp"
// Standard headers stay after <numpy/arrayobject.h>, which includes Python.h;
// Python's C-API docs ask for Python.h before any standard header.
#include <algorithm>  // std::find
#include <cmath>      // std::isfinite
#include <cstddef>    // std::size_t
#include <stdexcept>  // std::invalid_argument
#include <vector>     // std::vector
namespace py = pybind11;
template<class E1> | |
auto logsumexp1 (E1 const& e1) { | |
using value_type = typename std::decay_t<E1>::value_type; | |
auto max = xt::amax (e1)(); | |
return std::move (max + xt::log (xt::sum (xt::exp (e1-max)))); | |
} | |
template<class E1, class X> | |
auto logsumexp2 (const E1& e1, X const& axes) { | |
using value_type = typename std::decay_t<E1>::value_type; | |
auto max = xt::eval(xt::amax(e1, axes)); | |
auto sv = xt::slice_vector(max); | |
for (int i = 0; i < e1.dimension(); i++) | |
{ | |
if (std::find (axes.begin(), axes.end(), i) != axes.end()) | |
sv.push_back(xt::newaxis()); | |
else | |
sv.push_back(e1.shape()[i]); | |
} | |
auto max2 = xt::eval (xt::dynamic_view(max, sv)); | |
return (xt::pyarray<value_type>(max2 + xt::log(xt::sum(xt::exp(e1 - max2), axes)))); | |
} | |
template<class value_type> | |
auto normalize (xt::pyarray<value_type> const& e1) { | |
auto shape = std::vector<size_t>{e1.shape().size()-1}; | |
auto ls = logsumexp2 (e1, shape); | |
auto sv = xt::slice_vector(ls); | |
for (int i = 0; i < e1.dimension()-1; i++) | |
sv.push_back (xt::all()); | |
sv.push_back (xt::newaxis()); | |
auto ls2 = xt::dynamic_view (ls, sv); | |
return xt::pyarray<value_type> ((e1 - ls2)); | |
//return ls; | |
} | |
PYBIND11_PLUGIN (logsumexp) {
    // Initialize NumPy's C API before any pyarray conversion can happen;
    // on failure, report ImportError and give Python no module object.
    const bool numpy_ready = _import_array () >= 0;
    if (!numpy_ready) {
        PyErr_SetString (PyExc_ImportError, "numpy.core.multiarray failed to import");
        return nullptr;
    }

    py::module mod ("logsumexp", "pybind11 example plugin");

    // logsumexp(x): reduce over every element.
    auto lse_all = [](xt::pyarray<double> const& x) {
        return xt::pyarray<double> (xt::eval (logsumexp1 (x)));
    };
    // logsumexp(x, axes): reduce only over the listed axes. It is registered
    // second under the same name, so pybind11 tries it after the one-argument
    // overload.
    auto lse_axes = [](xt::pyarray<double> const& x, std::vector<size_t> const& axes) {
        return logsumexp2 (x, axes);
    };
    // normalize(x): subtract logsumexp over the last axis.
    auto norm_last = [](xt::pyarray<double> const& x) {
        return normalize (x);
    };

    mod.def ("logsumexp", lse_all);
    mod.def ("logsumexp", lse_axes);
    mod.def ("normalize", norm_last);
    return mod.ptr ();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from xtensor_test.logsumexp import logsumexp, normalize | |
def logsumexp_py1 (u):
    """Log-sum-exp over all elements of ``u``.

    Computes ``log(sum(exp(u)))`` stably as ``m + log(sum(exp(u - m)))`` with
    ``m = max(u)``. A non-finite maximum (all ``-inf``, or any ``+inf``) is
    replaced by 0. The result is then ``-inf`` / ``+inf`` instead of the NaN
    that ``-inf - -inf`` or ``inf - inf`` would produce.
    """
    m = np.max (u)
    if not np.isfinite (m):
        m = 0
    return m + np.log (np.sum (np.exp (u - m)))
def logsumexp_py2 (u, axes):
    """Log-sum-exp of ``u`` reduced over the axes in ``axes``.

    ``axes`` is an iterable of axis indices (negative indices allowed). The
    result has ``u``'s shape with those axes removed.

    The per-slice maximum is taken with ``keepdims=True``, so it broadcasts
    against ``u`` directly, and it is squeezed back before being added to the
    reduced sum. This replaces the old hand-built slice list, which:

    * indexed with a list of slices (rejected by modern NumPy);
    * indexed the reduced shape with unreduced axis numbers (IndexError when
      a leading axis is reduced);
    * added the keep-dims max to the reduced sum, broadcasting e.g. a
      (2, 4) input reduced over (1,) to a (2, 2) result.

    Non-finite maxima are replaced by 0, so all ``-inf`` slices give ``-inf``
    rather than NaN.
    """
    axes = tuple (axes)
    m = np.max (u, axis=axes, keepdims=True)
    m[~np.isfinite (m)] = 0
    s = np.log (np.sum (np.exp (u - m), axis=axes))
    return np.squeeze (m, axis=axes) + s
def logsumexp_py (u, axes=None):
    """NumPy reference log-sum-exp.

    With ``axes=None``, reduce over every element (``logsumexp_py1``).
    Otherwise reduce over the given axes (``logsumexp_py2``).
    """
    # `is None` rather than `== None`: equality would broadcast
    # element-wise if an array were passed as `axes`.
    if axes is None:
        return logsumexp_py1 (u)
    return logsumexp_py2 (u, axes)
# Smoke test: print the C++ extension's results next to the NumPy reference.
u = np.ones (4)
v = logsumexp(u)
print (v)
print (logsumexp_py (u, (0,)))
print (logsumexp (u, (0,)))
print ('logsumexp (np.ones ((2,4))):', logsumexp (np.ones ((2,4))))
print (logsumexp_py (np.ones ((2,4))))
print ('logsumexp_py (np.ones ((2,4)), (1,))):', logsumexp_py (np.ones ((2,4)), (1,)))
print ('logsumexp (np.ones ((2,4)), (1,)):', logsumexp (np.ones ((2,4)), (1,)))
#print (np.ones ((2,4)))
print ('norm:', normalize (np.ones ((2,4))))
# scipy.misc.logsumexp was deprecated in SciPy 1.0 and removed in 1.3;
# scipy.special.logsumexp is the same function under its current home.
from scipy.special import logsumexp as logsumexp2
print (logsumexp2 (u))
def normalize2 (u):
    """SciPy reference for ``normalize``: ``u`` minus its log-sum-exp over the
    last axis, broadcast back over that axis.

    Prints the intermediate log-sum and the input before returning.
    """
    lse = logsumexp2 (u, axis=-1)
    print ('ls:', lse)
    print ('u:', u)
    return u - lse[..., np.newaxis]
print ('norm2:', normalize2 (np.ones ((2,4))))

# 3-D input reduced over its last axis: extension vs. SciPy on the same array.
w = np.ones ((2,2,4))
print ('logsumexp (np.ones (2,2,4)), (2,)):', logsumexp (w, (2,)))
print ('logsumexp2 (np.ones (2,2,4)), (2,)):', logsumexp2 (w, (2,)))

# Rough timing: 10 calls each (SciPy first, then the extension) on a
# (2, 100000) array reduced along axis 1. The timeit setup imports `u` from
# __main__, so `u` is rebound here.
from timeit import timeit
u = np.ones ((2, 100000))
print (timeit ('logsumexp2 (u, (1,))', 'from __main__ import logsumexp2, u', number=10))
print (timeit ('logsumexp (u, (1,))', 'from __main__ import logsumexp, u', number=10))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment