Created August 19, 2020 08:08
-
-
Save samskalicky/750fc456fe838325c0701e30ac5cc3c8 to your computer and use it in GitHub Desktop.
MXNet Ops
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// put in src/operator and build MXNet | |
#include "operator_common.h" | |
extern "C" int listOps() { | |
// get op registry | |
::dmlc::Registry<::nnvm::Op>* reg = ::dmlc::Registry<::nnvm::Op>::Get(); | |
// get list of registered op names | |
std::vector<std::string> ops = reg->ListAllNames(); | |
// create inverse map of Op to name (to find aliases) | |
std::map<const ::nnvm::Op*,std::vector<std::string> > op_map; | |
for(auto &name : ops) { | |
const ::nnvm::Op* op = reg->Find(name); | |
if(op_map.count(op) > 0) { | |
if(name.compare(op->name) != 0) | |
op_map[op].push_back(name); | |
} else { | |
op_map[op]={}; | |
if(name.compare(op->name) != 0) | |
op_map[op].push_back(name); | |
} | |
} | |
// print out the op mapping | |
for(auto &kv : op_map) { | |
std::cout << kv.first->name << ", "; | |
for(auto &n : kv.second) | |
std::cout << n << ", "; | |
std::cout << std::endl; | |
} | |
return 0; | |
} | |
// Prints the op list once when this translation unit's static initializers run
// (i.e. when libmxnet is loaded), in addition to explicit calls such as the
// Python driver's mx.base._LIB.listOps().
// The anonymous namespace gives `n` internal linkage, so this generic name
// cannot collide with any other global `n` linked into libmxnet.
// NOTE(review): C++ leaves the order of dynamic initialization across
// translation units unspecified, so ops registered by static initializers in
// other source files may not be in the registry yet when this runs. The
// explicit call made after the library has loaded is the reliable way to get
// the full list.
namespace {
const int n = listOps();
}  // namespace
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Call the extern "C" listOps() hook compiled into libmxnet; it prints each
# registered operator followed by its aliases.
import mxnet as mx

lib = mx.base._LIB  # handle to the loaded libmxnet shared library
lib.listOps()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.