Last active
June 12, 2016 13:06
-
-
Save higumachan/1171386b7e48418082af7edea3eba99f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <type_traits> | |
#include <cmath> | |
// Expands to a fresh Variable<float, N> type. __COUNTER__ (a compiler
// extension supported by GCC, Clang and MSVC) increments on every expansion,
// so each VARIABLE gets its own ID. Variable::get_grad relies on that ID to
// decide whether two variables are the same one.
#define VARIABLE Variable<float, __COUNTER__>

// Forward declaration: the operator policies below name Variable<float, ID>
// in their grad() signatures before the class template is defined. The later
// definition supplies the default ID.
template <typename T, int ID>
class Variable;
namespace operators { | |
struct plus | |
{ | |
static float apply(float left, float right) | |
{ | |
return left + right; | |
} | |
template< | |
typename Left, | |
typename Right, | |
int ID | |
> | |
static float grad(const Variable<float, ID>& target, const Left& left, const Right& right) | |
{ | |
return left.get_grad(target) + right.get_grad(target); | |
} | |
}; | |
struct multiply | |
{ | |
static float apply(float left, float right) | |
{ | |
return left * right; | |
} | |
template< | |
typename Left, | |
typename Right, | |
int ID | |
> | |
static float grad(const Variable<float, ID>& target, const Left& left, const Right& right) | |
{ | |
return left.get_grad(target) * right.get_value() + left.get_value() * right.get_grad(target); | |
} | |
}; | |
struct sin | |
{ | |
static float apply(float operand) | |
{ | |
return std::sin(operand); | |
} | |
template< | |
typename Operand, | |
int ID | |
> | |
static float grad(const Variable<float, ID>& target, const Operand& operand) | |
{ | |
return std::cos(operand.get_value()) * operand.get_grad(target); | |
} | |
}; | |
} | |
// Leaf of the expression tree: a value of type T whose identity is the
// compile-time integer ID. Two Variables are treated as "the same variable"
// by get_grad exactly when their IDs match.
template <
    typename T,
    int ID = 0
>
class Variable
{
public:
    // This variable's identity, readable from other Variable specializations.
    static const int id = ID;

    Variable(const T& initial) : value(initial) {}

    // The value this variable was constructed with.
    T get_value() const
    {
        return value;
    }

    // Derivative of this variable with respect to `target`:
    // 1 when `target` carries the same ID, otherwise 0.
    template <int OTHER_ID>
    T get_grad(const Variable<T, OTHER_ID>& target) const
    {
        const bool same_variable = (target.id == ID);
        const int indicator = same_variable ? 1 : 0;
        return indicator;
    }

private:
    T value;
};
template< | |
typename Operator, | |
typename Operand | |
> | |
class UnaryExpression | |
{ | |
public: | |
UnaryExpression(const Operand& o) : operand(o) {} | |
float get_value() const | |
{ | |
return Operator::apply(operand.get_value()); | |
} | |
template<int ID> | |
float get_grad(const Variable<float, ID>& target) const | |
{ | |
return Operator::grad(target, operand); | |
} | |
private: | |
Operand operand; | |
}; | |
template< | |
class Left, | |
class Operator, | |
class Right | |
> | |
class BinaryExpression { | |
Left left; | |
Right right; | |
public: | |
BinaryExpression(const Left& l, const Right& r) | |
: left(l), right(r) | |
{} | |
float get_value() const | |
{ | |
return Operator::apply(left.get_value(), right.get_value()); | |
} | |
template<int ID> | |
float get_grad(const Variable<float, ID>& target) const | |
{ | |
return Operator::grad(target, left, right); | |
} | |
}; | |
template < | |
class Left, | |
class Right | |
> | |
BinaryExpression<Left, operators::plus, Right> operator+(const Left& left, const Right& right) | |
{ | |
return BinaryExpression<Left, operators::plus, Right>(left, right); | |
} | |
template < | |
class Left, | |
class Right | |
> | |
BinaryExpression<Left, operators::multiply, Right> operator*(const Left& left, const Right& right) | |
{ | |
return BinaryExpression<Left, operators::multiply, Right>(left, right); | |
} | |
template < | |
typename Operand | |
> | |
UnaryExpression<operators::sin, Operand> sin(const Operand& operand) | |
{ | |
return UnaryExpression<operators::sin, Operand>(operand); | |
} | |
int main(void) | |
{ | |
auto a = VARIABLE(1); | |
auto b = VARIABLE(2); | |
std::cout << "a=" << a.get_value() << std::endl; | |
std::cout << "a'=" << a.get_grad(a) << std::endl; | |
std::cout << "a*a'=" << (a * a).get_grad(a) << std::endl; | |
std::cout << "b=" << b.get_value() << std::endl; | |
std::cout << "b'=" << b.get_grad(b) << std::endl; | |
auto c = a + b; | |
std::cout << "c=" << c.get_value() << std::endl; | |
std::cout << "c'a=" << c.get_grad(a) << std::endl; | |
std::cout << "c'b=" << c.get_grad(b) << std::endl; | |
auto d = sin(c); | |
std::cout << "d=" << d.get_value() << std::endl; | |
auto temp = VARIABLE(3.141 / 2); | |
auto e = sin(temp); | |
std::cout << "e=" << e.get_value() << std::endl; | |
std::cout << "e'=" << e.get_grad(temp) << std::endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment