Skip to content

Instantly share code, notes, and snippets.

@orisano
Last active June 7, 2016 12:48
Show Gist options
  • Select an option

  • Save orisano/e0f8fc77783b7374ecae9d8a6de61877 to your computer and use it in GitHub Desktop.

Select an option

Save orisano/e0f8fc77783b7374ecae9d8a6de61877 to your computer and use it in GitHub Desktop.
自動微分のクラス (http://kivantium.hateblo.jp/entry/2016/03/25/010320 のパクリ)
// {{{ orliv::math::AD<T>
#ifndef INCLUDE_AD_HPP
#define INCLUDE_AD_HPP
#include <cmath>
#include <type_traits>
namespace orliv {
namespace math {
template <typename T>
struct AD {
static_assert(std::is_arithmetic<T>::value, "template arguments must be arithmetic type");
static constexpr double EPS = 1e-8;
AD(T val = 0, T d_val = 0) : val(val), d_val(d_val) {}
AD operator+(const AD& rhs) const {
AD tmp = *this;
return tmp += rhs;
}
AD& operator+=(const AD& rhs) {
val += rhs.val;
d_val += rhs.d_val;
return *this;
}
AD operator-(const AD& rhs) const {
AD tmp = *this;
return tmp -= rhs;
}
AD& operator-=(const AD& rhs) {
val -= rhs.val;
d_val -= rhs.d_val;
return *this;
}
AD operator*(const AD& rhs) const {
AD tmp = *this;
return tmp *= rhs;
}
AD& operator*=(const AD& rhs) {
d_val = d_val * rhs.val + val * rhs.d_val;
val *= rhs.val;
return *this;
}
AD operator/(const AD& rhs) const {
AD tmp = *this;
return tmp /= rhs;
}
AD& operator/=(const AD& rhs) {
val /= rhs.val;
d_val = (d_val - val * rhs.d_val) / rhs.val;
return *this;
}
AD& select() {
d_val = 1.0;
return *this;
}
AD operator+() const { return *this; }
AD operator-() const { return AD(-val, -d_val); }
T operator~() const { return d_val; }
T operator*() const { return val; }
bool operator==(const AD& rhs) const { return std::abs(val - rhs.val) < EPS; }
bool operator!=(const AD& rhs) const { return !(*this == rhs); }
bool operator<(const AD& rhs) const { return val < rhs.val; }
bool operator<=(const AD& rhs) const { return (*this == rhs) || (*this < rhs); }
bool operator>(const AD& rhs) const { return rhs < *this; }
bool operator>=(const AD& rhs) const { return !(*this < rhs); }
friend AD sqrt(const AD& x) {
T t = std::sqrt(x.val);
return AD(t, 0.5 * x.d_val / t);
}
friend AD exp(const AD& x) {
T t = std::exp(x.val);
return AD(t, x.d_val * t);
}
friend AD log(const AD& x) { return AD(std::log(x.val), x.d_val / x.val); }
friend AD sin(const AD& x) { return AD(std::sin(x.val), std::cos(x.val)); }
friend AD cos(const AD& x) { return AD(std::cos(x.val), -std::sin(x.val)); }
private:
T val, d_val;
};
}
}
#endif
// }}}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment