Created
March 30, 2013 21:25
-
-
Save kmike/5278404 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
from __future__ import absolute_import | |
from libc.math cimport log2 # FIXME: can be unavailable in Windows | |
cimport cython | |
ctypedef double dtype_t | |
DEF _NINF = float('-inf') | |
def log_add(*values):
    """
    Add probabilities given as base-2 logarithms; return log2 of their sum.

    Idea (for 2 variables): assume we want to add P1 and P2 and that
    P1 >= P2. Then

        log2(P1 + P2) = log2(P1) + log2(1 + 2^(log2(P2) - log2(P1)))

    so this function accepts log2(P1) and log2(P2) and returns
    log2(P1 + P2) computed using the formula above. Factoring out the
    largest term keeps every exponent <= 0, so the powers cannot overflow.

    With no arguments, or when the largest argument is -inf (every
    probability is zero), the result is -inf. When the largest argument
    is +inf the result is +inf.
    """
    cdef double value
    cdef double sum_diffs
    cdef double largest = _NINF

    # find maximum value (builtin 'max' is unrolled for speed)
    for value in values:
        if value > largest:
            largest = value

    # An infinite maximum is returned as is: subtracting it from itself
    # in the loop below would yield inf - inf = NaN.
    if largest == _NINF or largest == -_NINF:
        return largest

    sum_diffs = 0
    for value in values:
        sum_diffs += 2 ** (value - largest)
    return largest + log2(sum_diffs)
@cython.boundscheck(False)
cdef dtype_t _max(dtype_t[:] values):
    """
    Return the largest element of ``values`` using a plain linear scan.

    The scan starts from _NINF, so that is the result for an empty view.
    """
    cdef Py_ssize_t idx
    cdef Py_ssize_t count = values.shape[0]
    cdef dtype_t best = _NINF
    for idx in range(count):
        if values[idx] > best:
            best = values[idx]
    return best
@cython.boundscheck(False)
def logsumexp2(dtype_t[:] arr):
    """
    Return log2(sum(2**x for x in arr)) for a 1-d view of doubles.

    The largest element is factored out before exponentiating, so every
    exponent is <= 0 and the powers cannot overflow:

        log2(sum(2**x)) = vmax + log2(sum(2**(x - vmax)))

    For an empty view, or when the largest element is -inf, the result
    is -inf. When the largest element is +inf the result is +inf.
    """
    cdef Py_ssize_t i
    cdef dtype_t vmax = _max(arr)
    cdef dtype_t power_sum = 0

    # An infinite maximum is returned as is: subtracting it from itself
    # in the loop below would yield inf - inf = NaN.
    if vmax == _NINF or vmax == -_NINF:
        return vmax

    for i in range(arr.shape[0]):
        power_sum += 2 ** (arr[i] - vmax)
    return log2(power_sum) + vmax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment