Last active
June 7, 2016 12:48
-
-
Save orisano/e0f8fc77783b7374ecae9d8a6de61877 to your computer and use it in GitHub Desktop.
自動微分のクラス (http://kivantium.hateblo.jp/entry/2016/03/25/010320 のパクリ)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // {{{ orliv::math::AD<T> | |
| #ifndef INCLUDE_AD_HPP | |
| #define INCLUDE_AD_HPP | |
| #include <cmath> | |
| #include <type_traits> | |
| namespace orliv { | |
| namespace math { | |
// Forward-mode automatic differentiation value (a "dual number").
//
// An AD<T> stores a value together with the derivative of that value with
// respect to one chosen input variable. The arithmetic operators and the math
// functions declared here update both parts using the standard differentiation
// rules (sum, product, quotient and chain rule).
//
// Typical use: mark the variable of differentiation with select(), evaluate
// the expression, then read the result with unary * (value) and unary ~
// (derivative).
template <typename T>
struct AD {
  static_assert(std::is_arithmetic<T>::value, "template arguments must be arithmetic type");

  // Absolute tolerance used by operator== (and so by !=, <= and >=).
  static constexpr double EPS = 1e-8;

  // Constructs from a value and its derivative. Both default to 0, so a
  // one-argument construction yields a constant (derivative 0). Deliberately
  // not explicit so that a plain T can be used as a right-hand operand.
  AD(T val = 0, T d_val = 0) : val(val), d_val(d_val) {}

  // Sum rule: (u + v)' = u' + v'.
  AD operator+(const AD& rhs) const {
    AD tmp = *this;
    return tmp += rhs;
  }
  AD& operator+=(const AD& rhs) {
    val += rhs.val;
    d_val += rhs.d_val;
    return *this;
  }

  // Difference rule: (u - v)' = u' - v'.
  AD operator-(const AD& rhs) const {
    AD tmp = *this;
    return tmp -= rhs;
  }
  AD& operator-=(const AD& rhs) {
    val -= rhs.val;
    d_val -= rhs.d_val;
    return *this;
  }

  // Product rule: (u * v)' = u' * v + u * v'.
  AD operator*(const AD& rhs) const {
    AD tmp = *this;
    return tmp *= rhs;
  }
  AD& operator*=(const AD& rhs) {
    // Read the right-hand side first so that `x *= x` (rhs aliasing *this)
    // still sees the operands as they were before the update.
    const T v = rhs.val;
    const T dv = rhs.d_val;
    d_val = d_val * v + val * dv;
    val *= v;
    return *this;
  }

  // Quotient rule, written in terms of the quotient q = u / v:
  // (u / v)' = (u' - q * v') / v.
  AD operator/(const AD& rhs) const {
    AD tmp = *this;
    return tmp /= rhs;
  }
  AD& operator/=(const AD& rhs) {
    // Read the right-hand side first so that `x /= x` is well defined.
    const T v = rhs.val;
    const T dv = rhs.d_val;
    val /= v;  // val now holds the quotient q
    d_val = (d_val - val * dv) / v;
    return *this;
  }

  // Marks this object as the variable of differentiation by setting its
  // derivative to 1. Returns *this to allow chaining.
  AD& select() {
    d_val = static_cast<T>(1);
    return *this;
  }

  AD operator+() const { return *this; }
  AD operator-() const { return AD(-val, -d_val); }

  // Unary ~ returns the derivative part.
  T operator~() const { return d_val; }
  // Unary * returns the value part.
  T operator*() const { return val; }

  // Comparisons look at the value part only; the derivative is ignored.
  // Equality is approximate (absolute difference below EPS).
  bool operator==(const AD& rhs) const { return std::abs(val - rhs.val) < EPS; }
  bool operator!=(const AD& rhs) const { return !(*this == rhs); }
  // Strict ordering is an exact comparison of the values (no tolerance).
  bool operator<(const AD& rhs) const { return val < rhs.val; }
  bool operator<=(const AD& rhs) const { return (*this == rhs) || (*this < rhs); }
  bool operator>(const AD& rhs) const { return rhs < *this; }
  // Mirrors operator<= so that `a == b` always implies `a >= b`, even when
  // a.val is smaller than b.val by less than EPS.
  bool operator>=(const AD& rhs) const { return (*this == rhs) || (*this > rhs); }

  // Math functions, found through argument-dependent lookup. Each applies the
  // chain rule: f(u)' = f'(u) * u'.

  // sqrt(u)' = u' / (2 * sqrt(u)).
  friend AD sqrt(const AD& x) {
    T t = std::sqrt(x.val);
    return AD(t, 0.5 * x.d_val / t);
  }
  // exp(u)' = u' * exp(u).
  friend AD exp(const AD& x) {
    T t = std::exp(x.val);
    return AD(t, x.d_val * t);
  }
  // log(u)' = u' / u.
  friend AD log(const AD& x) { return AD(std::log(x.val), x.d_val / x.val); }
  // sin(u)' = u' * cos(u). The u' factor is required: without it the sine of
  // a constant would report a non-zero derivative.
  friend AD sin(const AD& x) { return AD(std::sin(x.val), x.d_val * std::cos(x.val)); }
  // cos(u)' = -u' * sin(u).
  friend AD cos(const AD& x) { return AD(std::cos(x.val), -x.d_val * std::sin(x.val)); }

 private:
  T val, d_val;  // value and its derivative w.r.t. the selected variable
};
| } | |
| } | |
| #endif | |
| // }}} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment