Created
April 16, 2012 18:17
-
-
Save redpony/2400470 to your computer and use it in GitHub Desktop.
C++ class to represent real numbers in the log domain
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef LOGVAL_H_ | |
#define LOGVAL_H_ | |
// Represents values internally in the log domain, which gives a much larger effective
// range than float, double, or long double. Useful for very small probabilities or very
// large unnormalized probabilities, where plain floating point would underflow or overflow.
//
// Use it as if it were a double or float, e.g. for doubles a and b:
//   a * b  corresponds to  LogVal<double>(a) * LogVal<double>(b)
//   a + b  corresponds to  LogVal<double>(a) + LogVal<double>(b)
//
// Placed by Chris Dyer <[email protected]> in the public domain on April 16, 2012. | |
// | |
#include <iostream> | |
#include <cstdlib> | |
#include <cmath> | |
#include <limits> | |
#include <cassert> | |
template <class T> | |
class LogVal { | |
public: | |
typedef LogVal<T> Self; | |
LogVal() : s_(), v_(std::log(T())) {} | |
LogVal(double x) : s_(std::signbit(x)), v_(s_ ? std::log(-x) : std::log(x)) {} | |
const Self& operator=(double x) { s_ = std::signbit(x); v_ = s_ ? std::log(-x) : std::log(x); return *this; } | |
LogVal(double lnx,bool sign) : s_(sign),v_(lnx) {} | |
static Self exp(T lnx) { return Self(lnx,false); } | |
static Self One() { return Self(1); } | |
static Self Zero() { return Self(); } | |
static Self e() { return Self(1,false); } | |
void logeq(const T& v) { s_ = false; v_ = v; } | |
// true is negative, false is positive | |
bool signbit() const { | |
return s_; | |
} | |
Self& operator+=(const Self& a) { | |
if (a == Zero()) return *this; | |
if (a.s_ == s_) { | |
if (a.v_ < v_) { | |
v_ = v_ + log1p(std::exp(a.v_ - v_)); | |
} else { | |
v_ = a.v_ + log1p(std::exp(v_ - a.v_)); | |
} | |
} else { | |
if (a.v_ < v_) { | |
v_ = v_ + log1p(-std::exp(a.v_ - v_)); | |
} else { | |
v_ = a.v_ + log1p(-std::exp(v_ - a.v_)); | |
s_ = !s_; | |
} | |
} | |
return *this; | |
} | |
Self& operator*=(const Self& a) { | |
s_ = (s_ != a.s_); | |
v_ += a.v_; | |
return *this; | |
} | |
Self& operator/=(const Self& a) { | |
s_ = (s_ != a.s_); | |
v_ -= a.v_; | |
return *this; | |
} | |
Self& operator-=(const Self& a) { | |
Self b = a; | |
b.negate(); | |
return *this += b; | |
} | |
friend Self abslog(Self x) { | |
if (x.v_<0) x.v_=-x.v_; | |
return x; | |
} | |
Self& poweq(const T& power) { | |
if (s_) { | |
std::cerr << "poweq(T) not implemented when s_ is true\n"; | |
std::abort(); | |
} else | |
v_ *= power; | |
return *this; | |
} | |
//remember, s_ means negative. | |
inline bool lt(Self const& o) const { | |
return s_==o.s_ ? v_ < o.v_ : s_ > o.s_; | |
} | |
inline bool gt(Self const& o) const { | |
return s_==o.s_ ? o.v_ < v_ : s_ < o.s_; | |
} | |
Self operator-() const { | |
return Self(v_,!s_); | |
} | |
void negate() { s_ = !s_; } | |
Self inverse() const { return Self(-v_,s_); } | |
Self pow(const T& power) const { | |
Self res = *this; | |
res.poweq(power); | |
return res; | |
} | |
Self root(const T& root) const { | |
return pow(1/root); | |
} | |
T as_float() const { | |
if (s_) return -std::exp(v_); else return std::exp(v_); | |
} | |
bool s_; | |
T v_; | |
}; | |
template <class T> | |
inline std::ostream& operator<<(std::ostream& os, const LogVal<T>& v) { | |
if (v.s_) os<<"(-)"; | |
return os<<v.v_; | |
} | |
template<class T> | |
LogVal<T> operator+(LogVal<T> o1, const LogVal<T>& o2) { | |
o1 += o2; | |
return o1; | |
} | |
template<class T> | |
LogVal<T> operator*(LogVal<T> o1, const LogVal<T>& o2) { | |
o1 *= o2; | |
return o1; | |
} | |
template<class T> | |
LogVal<T> operator/(LogVal<T> o1, const LogVal<T>& o2) { | |
o1 /= o2; | |
return o1; | |
} | |
template<class T> | |
LogVal<T> operator-(LogVal<T> o1, const LogVal<T>& o2) { | |
o1 -= o2; | |
return o1; | |
} | |
template<class T> | |
T log(const LogVal<T>& o) { | |
if (o.s_) return log(-1.0); | |
return o.v_; | |
} | |
template<class T> | |
inline bool signbit(const LogVal<T>& x) { return x.signbit(); } | |
template<class T> | |
inline LogVal<T> abs(const LogVal<T>& o) { | |
if (o.s_) { | |
LogVal<T> res = o; | |
res.s_ = false; | |
return res; | |
} else { return o; } | |
} | |
template <class T> | |
inline LogVal<T> pow(const LogVal<T>& b, const T& e) { | |
return b.pow(e); | |
} | |
template <class T> | |
bool operator==(const LogVal<T>& lhs, const LogVal<T>& rhs) { | |
return (lhs.v_ == rhs.v_) && (lhs.s_ == rhs.s_); | |
} | |
template <class T> | |
bool operator!=(const LogVal<T>& lhs, const LogVal<T>& rhs) { | |
return !(lhs == rhs); | |
} | |
template <class T> | |
bool operator<(const LogVal<T>& lhs, const LogVal<T>& rhs) { | |
if (lhs.s_ == rhs.s_) { | |
return (lhs.v_ < rhs.v_); | |
} else { | |
return lhs.s_ > rhs.s_; | |
} | |
} | |
template <class T> | |
bool operator<=(const LogVal<T>& lhs, const LogVal<T>& rhs) { | |
return (lhs < rhs) || (lhs == rhs); | |
} | |
template <class T> | |
bool operator>(const LogVal<T>& lhs, const LogVal<T>& rhs) { | |
return !(lhs <= rhs); | |
} | |
template <class T> | |
bool operator>=(const LogVal<T>& lhs, const LogVal<T>& rhs) { | |
return !(lhs < rhs); | |
} | |
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment