Created October 3, 2016 13:42
-
-
Save lindahua/33110dfbcb1542f3742474aaf46b61af to your computer and use it in GitHub Desktop.
Simplified way to register operations (proof of concept)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// A proof-of-concept demonstration of operation definition and registration
#include <cassert>
#include <cmath>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
// One named, real-valued attribute used to configure an operation
// (for example, the exponent "p" consumed by the "pow" creator).
struct Attr {
    std::string name;  // attribute key
    double value;      // attribute payload
};

// Attributes handed to an operation creator, in caller-supplied order.
using attr_list_t = std::vector<Attr>;

// Forward-computation signature:
//   (number of arguments, pointer to the arguments) -> scalar result.
using calc_func_t = std::function<double(size_t, double const*)>;
// Operation that actually performs the computation | |
class Op final { | |
private: | |
calc_func_t fwd_; | |
public: | |
Op(calc_func_t f) : fwd_(f) {} | |
double forward(size_t n, const double* args) const { | |
return fwd_(n, args); | |
} | |
}; | |
using op_creator_t = std::function<Op(const attr_list_t&)>; | |
// Proto that specifies useful meta information | |
class OpProto final { | |
private: | |
std::string name_; | |
size_t ninputs_ = 0; | |
op_creator_t opcreator_; | |
public: | |
OpProto(const std::string& name) | |
: name_(name) {} | |
const std::string& name() const { | |
return name_; | |
} | |
OpProto& set_ninputs(size_t n) { | |
ninputs_ = n; | |
return *this; | |
} | |
size_t ninputs() const { | |
return ninputs_; | |
} | |
OpProto& set_opcreator(op_creator_t cf) { | |
opcreator_ = cf; | |
return *this; | |
} | |
Op createOp(const attr_list_t& attrs) const { | |
return opcreator_(attrs); | |
} | |
}; | |
// Registration facilities | |
static std::unordered_map<std::string, OpProto> registry; | |
inline OpProto& registerOp(const std::string& name) { | |
registry.emplace(name, OpProto(name)); | |
return registry.at(name); | |
} | |
inline const OpProto& getOp(const std::string& name) { | |
return registry.at(name); | |
} | |
// main | |
int main() { | |
// register operations | |
registerOp("add") | |
.set_ninputs(2) | |
.set_opcreator([](const attr_list_t&){ | |
return Op([](size_t n, const double *args){ | |
assert(n == 2); | |
return args[0] + args[1]; | |
}); | |
}); | |
registerOp("mul") | |
.set_ninputs(2) | |
.set_opcreator([](const attr_list_t&){ | |
return Op([](size_t n, const double *args){ | |
assert(n == 2); | |
return args[0] * args[1]; | |
}); | |
}); | |
registerOp("pow") | |
.set_ninputs(1) | |
.set_opcreator([](const attr_list_t& attrs){ | |
double p = 1; | |
// extract attributes | |
for (const auto& a: attrs) { | |
if (a.name == "p") { p = a.value; } | |
} | |
// create operator accordingly | |
return Op([p](size_t n, const double *args){ | |
assert(n == 1); | |
return std::pow(args[0], p); | |
}); | |
}); | |
// set a sequence of computations to be done | |
using item_t = std::tuple<OpProto, attr_list_t, std::vector<double>>; | |
std::vector<item_t> items { | |
item_t{getOp("add"), {}, {1.0, 2.0}}, | |
item_t{getOp("mul"), {}, {2.0, 3.0}}, | |
item_t{getOp("pow"), {{"p", 2.0}}, {5.0}}, | |
}; | |
// run the items | |
for (const auto& item: items) { | |
const OpProto& proto = std::get<0>(item); | |
const attr_list_t& attrs = std::get<1>(item); | |
const std::vector<double>& args = std::get<2>(item); | |
Op op = proto.createOp(attrs); | |
double r = op.forward(args.size(), args.data()); | |
std::cout << proto.name() << " --> " << r << std::endl; | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.