Skip to content

Instantly share code, notes, and snippets.

@kmike
Created March 30, 2013 21:25
Show Gist options
  • Save kmike/5278404 to your computer and use it in GitHub Desktop.
Save kmike/5278404 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from libc.math cimport log2 # FIXME: can be unavailable in Windows
cimport cython
# Element type used by the typed-memoryview signatures of _max() and
# logsumexp2().
ctypedef double dtype_t
# Compile-time constant holding negative infinity; it is the starting value of
# the running-maximum searches and the "sum is zero" result of log_add().
# NOTE(review): DEF appears to be deprecated as of Cython 3 — confirm before
# upgrading the compiler.
DEF _NINF = float('-inf')
def log_add(*values):
    """
    Add probabilities given as base-2 logarithms and return the base-2
    logarithm of their sum.

    Idea (for 2 variables): let's assume we want to add P1 and P2 and
    that P1 >= P2. We have

        log2(P1 + P2) = log2(P1) + log2(1 + 2^(log2(P2) - log2(P1)))

    So this function accepts log2(P1) and log2(P2) and returns
    log2(P1 + P2) computed using the formula above. Factoring out the
    largest term means every exponent passed to ``2**`` is <= 0.

    :param values: any number of base-2 logarithms; each one is coerced
        to a C double.
    :return: ``log2(sum(2**v for v in values))``. ``-inf`` when called
        without arguments or when no argument exceeds ``-inf``; ``+inf``
        when the largest argument is ``+inf``.
    """
    cdef double value
    cdef double sum_diffs = 0
    # Called 'largest' (not '_max') so the local does not shadow the
    # module-level _max() helper.
    cdef double largest = _NINF

    # find maximum value (builtin 'max' is unrolled for speed)
    for value in values:
        if value > largest:
            largest = value

    # An infinite maximum must be returned as-is: subtracting it below
    # would evaluate inf - inf, which is nan.
    #   -inf -> the sum of probabilities is zero (or there were no values)
    #   +inf -> the sum is infinite
    if largest == _NINF or largest == -_NINF:
        return largest

    for value in values:
        sum_diffs += 2 ** (value - largest)
    return largest + log2(sum_diffs)
@cython.boundscheck(False)
cdef dtype_t _max(dtype_t[:] values):
    # Linear scan for the largest element of `values`.
    # Returns _NINF (negative infinity) when the buffer is empty or when no
    # element compares greater than it.
    cdef Py_ssize_t idx = 0
    cdef Py_ssize_t count = values.shape[0]
    cdef dtype_t best = _NINF
    cdef dtype_t candidate
    while idx < count:
        candidate = values[idx]
        if candidate > best:
            best = candidate
        idx += 1
    return best
@cython.boundscheck(False)
def logsumexp2(dtype_t[:] arr):
    """
    Return ``log2(sum(2**x for x in arr))`` for a one-dimensional buffer
    of doubles.

    The maximum element is factored out of the sum first, so every
    exponent passed to ``2**`` is <= 0 (same scheme as ``log_add``).

    :param arr: one-dimensional typed memoryview of C doubles holding
        base-2 logarithms.
    :return: base-2 logarithm of the summed powers of two. ``-inf`` for an
        empty buffer or when no element exceeds ``-inf``; ``+inf`` when the
        maximum element is ``+inf``.
    """
    cdef Py_ssize_t i
    cdef dtype_t vmax = _max(arr)
    cdef dtype_t power_sum = 0

    # Same guard as log_add(): with an infinite maximum, 'arr[i] - vmax'
    # would evaluate inf - inf = nan, so the maximum itself is the answer.
    if vmax == _NINF or vmax == -_NINF:
        return vmax

    for i in range(arr.shape[0]):
        power_sum += 2 ** (arr[i] - vmax)
    return log2(power_sum) + vmax
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment