Skip to content

Instantly share code, notes, and snippets.

@higumachan
Last active June 12, 2016 13:06
Show Gist options
  • Save higumachan/1171386b7e48418082af7edea3eba99f to your computer and use it in GitHub Desktop.
Save higumachan/1171386b7e48418082af7edea3eba99f to your computer and use it in GitHub Desktop.
#include <iostream>
#include <type_traits>
#include <cmath>
// Declares a fresh Variable<float, N> type with a distinct N at each use,
// e.g. `auto x = VARIABLE(1.0f);`.
//
// __COUNTER__ is a compiler extension (GCC, Clang, MSVC), not standard C++.
// It expands to sequential integers starting at 0, once per expansion.
// The +1 keeps every VARIABLE ID away from 0, which is the default ID of
// Variable<T> (see the class definition). Without it, the first VARIABLE in a
// translation unit would share its identity with any plain Variable<float>,
// and get_grad would treat the two as the same variable.
// NOTE(review): IDs are only unique within one translation unit.
#define VARIABLE Variable<float, (__COUNTER__ + 1)>

// Forward declaration: the operator structs below name Variable<float, ID>
// before the class is defined. The definition supplies the default ID.
template<
typename T,
int ID
>
class Variable;
/// Primitive operations used by the expression-template nodes.
///
/// Each operator struct exposes two static functions:
///  - apply(...)        evaluates the operation on plain floats;
///  - grad(target, ...) derivative of the operation with respect to the
///                      variable `target`, assembled from the operands'
///                      get_value() and get_grad(target).
namespace operators {

/// Sum rule: (l + r)' = l' + r'.
struct plus
{
    static float apply(float lhs, float rhs)
    {
        const float sum = lhs + rhs;
        return sum;
    }

    template<typename Left, typename Right, int ID>
    static float grad(const Variable<float, ID>& target, const Left& left, const Right& right)
    {
        const float d_left = left.get_grad(target);
        const float d_right = right.get_grad(target);
        return d_left + d_right;
    }
};

/// Product rule: (l * r)' = l' * r + l * r'.
struct multiply
{
    static float apply(float lhs, float rhs)
    {
        const float product = lhs * rhs;
        return product;
    }

    template<typename Left, typename Right, int ID>
    static float grad(const Variable<float, ID>& target, const Left& left, const Right& right)
    {
        // Kept as a single expression so floating-point contraction (FMA)
        // is applied exactly as before.
        return left.get_grad(target) * right.get_value()
             + left.get_value() * right.get_grad(target);
    }
};

/// Chain rule for sine: (sin u)' = cos(u) * u'.
struct sin
{
    static float apply(float x)
    {
        return std::sin(x);
    }

    template<typename Operand, int ID>
    static float grad(const Variable<float, ID>& target, const Operand& operand)
    {
        const float outer = std::cos(operand.get_value());
        const float inner = operand.get_grad(target);
        return outer * inner;
    }
};

} // namespace operators
/// Leaf node of the expression tree: a named scalar value.
///
/// ID is a compile-time identity tag. Two Variables of the same T count as
/// the same variable for differentiation exactly when their IDs are equal;
/// the stored value plays no part in that decision.
template<typename T, int ID = 0>
class Variable
{
public:
    /// Identity tag, readable from other Variable instantiations.
    static const int id = ID;

    /// Stores @p initial as the variable's value.
    Variable(const T& initial) : stored(initial) {}

    /// Returns the stored value.
    T get_value() const { return stored; }

    /// Derivative of this variable with respect to @p wrt:
    /// 1 when @p wrt carries the same ID, otherwise 0.
    template <int OTHER_ID>
    T get_grad(const Variable<T, OTHER_ID>& wrt) const
    {
        const bool same_variable = (wrt.id == ID);
        if (same_variable)
        {
            return 1;
        }
        return 0;
    }

private:
    T stored;
};
/// Expression node applying a one-argument Operator to a sub-expression.
///
/// @tparam Operator one of the structs in namespace operators; must provide
///                  static apply(float) and grad(target, operand).
/// @tparam Operand  type of the sub-expression; a copy is stored in the node.
template<typename Operator, typename Operand>
class UnaryExpression
{
    Operand inner;

public:
    UnaryExpression(const Operand& o) : inner(o) {}

    /// Evaluates the sub-expression, then applies Operator to the result.
    float get_value() const
    {
        const float x = inner.get_value();
        return Operator::apply(x);
    }

    /// Derivative with respect to @p target; delegates to Operator::grad.
    template<int ID>
    float get_grad(const Variable<float, ID>& target) const
    {
        return Operator::grad(target, inner);
    }
};
/// Expression node combining two sub-expressions with a binary Operator.
///
/// @tparam Left     type of the left operand; a copy is stored in the node.
/// @tparam Operator one of the structs in namespace operators; must provide
///                  static apply(float, float) and grad(target, left, right).
/// @tparam Right    type of the right operand; a copy is stored in the node.
template<class Left, class Operator, class Right>
class BinaryExpression {
public:
    BinaryExpression(const Left& l, const Right& r) : lhs(l), rhs(r) {}

    /// Evaluates both operands and combines them with Operator::apply.
    float get_value() const
    {
        const float a = lhs.get_value();
        const float b = rhs.get_value();
        return Operator::apply(a, b);
    }

    /// Derivative with respect to @p target; delegates to Operator::grad.
    template<int ID>
    float get_grad(const Variable<float, ID>& target) const
    {
        return Operator::grad(target, lhs, rhs);
    }

private:
    Left lhs;
    Right rhs;
};
template <
class Left,
class Right
>
BinaryExpression<Left, operators::plus, Right> operator+(const Left& left, const Right& right)
{
return BinaryExpression<Left, operators::plus, Right>(left, right);
}
template <
class Left,
class Right
>
BinaryExpression<Left, operators::multiply, Right> operator*(const Left& left, const Right& right)
{
return BinaryExpression<Left, operators::multiply, Right>(left, right);
}
template <
typename Operand
>
UnaryExpression<operators::sin, Operand> sin(const Operand& operand)
{
return UnaryExpression<operators::sin, Operand>(operand);
}
int main(void)
{
    // Writes "<label><value>" to stdout, terminated by std::endl.
    auto show = [](const char* label, float value) {
        std::cout << label << value << std::endl;
    };

    // Two independent variables and their self-derivatives.
    auto a = VARIABLE(1);
    auto b = VARIABLE(2);
    show("a=", a.get_value());
    show("a'=", a.get_grad(a));
    show("a*a'=", (a * a).get_grad(a));
    show("b=", b.get_value());
    show("b'=", b.get_grad(b));

    // A sum and its partial derivatives with respect to each input.
    auto c = a + b;
    show("c=", c.get_value());
    show("c'a=", c.get_grad(a));
    show("c'b=", c.get_grad(b));

    auto d = sin(c);
    show("d=", d.get_value());

    // sin evaluated near 3.141 / 2, plus its derivative there.
    auto temp = VARIABLE(3.141 / 2);
    auto e = sin(temp);
    show("e=", e.get_value());
    show("e'=", e.get_grad(temp));
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment