Last active
May 1, 2020 07:08
-
-
Save samskalicky/5f44e159e9f1b04237eed8d20e5d9f28 to your computer and use it in GitHub Desktop.
Example graph pass
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <math.h>
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lib_api.h"
class Node;

/* \brief one end of an edge in the graph: a node plus an index on that node.
 * When stored in Node::inputs / Graph::outputs, `entry` is the output index of
 * the producing node; when stored in Node::outputs, it is the input slot of
 * the consuming node (see Graph::fromJson). */
struct NodeEntry {
  Node* node;  // non-owning pointer to the node at this end of the edge
  int entry;   // index on `node` (meaning depends on the list, see above)
};
class Node { | |
public: | |
std::string op,name; | |
std::vector<NodeEntry> inputs; | |
std::vector<NodeEntry> outputs; | |
std::unordered_map<std::string, std::string> attrs; | |
}; | |
class Graph { | |
public: | |
Graph() {} | |
static Graph fromString(const std::string& json) { | |
JsonParser parser; | |
JsonVal val = parser.parse_to_json(json); | |
return fromJson(val); | |
} | |
~Graph() { | |
for(int i=0; i<nodes.size(); i++) | |
delete nodes[i]; | |
} | |
static Graph fromJson(JsonVal val) { | |
// get nodes list | |
JsonVal nodes = val.map[JsonVal("nodes")]; | |
Graph g; | |
std::map<int, Node*> nodeMap; | |
// loop over nodes | |
for(int i=0; i<nodes.list.size(); i++) { | |
Node* n = new Node(); | |
g.nodes.push_back(n); | |
JsonVal node = nodes.list[i]; | |
// set the op info | |
n->op = node.map[JsonVal("op")].str; | |
n->name = node.map[JsonVal("name")].str; | |
// if op is null its an input to the graph | |
if(n->op.compare("null") == 0) | |
g.inputs.push_back(n); | |
// set attrs | |
JsonVal attributes = node.map[JsonVal("attrs")]; | |
for(auto& kv : attributes.map) { | |
n->attrs[kv.first.str] = kv.second.str; | |
} | |
// set node inputs | |
JsonVal node_inputs = node.map[JsonVal("inputs")]; | |
n->inputs.resize(node_inputs.list.size()); | |
for(int j=0; j<node_inputs.list.size(); j++) { | |
JsonVal input = node_inputs.list[j]; | |
NodeEntry& entry = n->inputs[j]; | |
//get pointer to other node | |
entry.node = nodeMap[input.list[0].num]; | |
//get the other node's output index | |
entry.entry = input.list[1].num; | |
//set other nodes output as connected to this node | |
entry.node->outputs.push_back({n,j}); | |
} | |
nodeMap[i] = n; | |
} | |
JsonVal& heads = val.map[JsonVal("heads")]; | |
g.outputs.resize(heads.list.size()); | |
for(int i=0; i<heads.list.size(); i++) { | |
JsonVal head = heads.list[i]; | |
g.outputs[i].node = nodeMap[head.list[0].num]; | |
g.outputs[i].entry = head.list[1].num; | |
} | |
JsonParser parser; | |
for(auto& kv : val.map) { | |
if(kv.first.str.compare("nodes") != 0 && | |
kv.first.str.compare("heads") != 0 && | |
kv.first.str.compare("node_row_ptr") != 0 && | |
kv.first.str.compare("arg_nodes") != 0) { | |
g.attrs[kv.first.str] = kv.second; | |
} | |
} | |
return g; | |
} | |
JsonVal toJson() { | |
JsonVal val(MAP); | |
for(auto& kv : attrs) { | |
val.map[JsonVal(kv.first)] = kv.second; | |
} | |
std::map<Node*, int> nodeMap; | |
std::vector<Node*> sorted = topological_sort(); | |
for(int i=sorted.size()-1; i>=0; i--) { | |
nodeMap[sorted[i]] = sorted.size()-1-i; | |
} | |
val.map[JsonVal("node_row_ptr")] = JsonVal(LIST); | |
JsonVal& node_row_ptr = val.map[JsonVal("node_row_ptr")]; | |
for(int i=0; i<nodes.size(); i++) | |
node_row_ptr.list.push_back(JsonVal(i)); | |
val.map[JsonVal("arg_nodes")] = JsonVal(LIST); | |
JsonVal& arg_nodes = val.map[JsonVal("arg_nodes")]; | |
for(int i=0; i<inputs.size(); i++) | |
arg_nodes.list.push_back(JsonVal(nodeMap[inputs[i]])); | |
val.map[JsonVal("heads")] = JsonVal(LIST); | |
JsonVal& heads = val.map[JsonVal("heads")]; | |
for(int i=0; i<outputs.size(); i++) { | |
heads.list.push_back(JsonVal(LIST)); | |
JsonVal& out = heads.list[i]; | |
out.list.push_back(JsonVal(nodeMap[outputs[i].node])); | |
out.list.push_back(JsonVal(outputs[i].entry)); | |
out.list.push_back(JsonVal(0)); | |
} | |
val.map[JsonVal("nodes")] = JsonVal(LIST); | |
JsonVal& nodes_ = val.map[JsonVal("nodes")]; | |
for(int i=sorted.size()-1; i>=0; i--) { | |
nodes_.list.push_back(JsonVal(MAP)); | |
Node* n = sorted[i]; | |
JsonVal& n_ = nodes_.list[nodes_.list.size()-1]; | |
n_.map[JsonVal("op")] = JsonVal(n->op); | |
n_.map[JsonVal("name")] = JsonVal(n->name); | |
n_.map[JsonVal("inputs")] = JsonVal(LIST); | |
JsonVal& inputs_ = n_.map[JsonVal("inputs")]; | |
for(int j=0; j<n->inputs.size(); j++) { | |
inputs_.list.push_back(JsonVal(LIST)); | |
NodeEntry& entry = n->inputs[j]; | |
JsonVal& in = inputs_.list[j]; | |
in.list.push_back(JsonVal(nodeMap[entry.node])); | |
in.list.push_back(JsonVal(entry.entry)); | |
in.list.push_back(JsonVal(0)); | |
} | |
n_.map[JsonVal("attrs")] = JsonVal(MAP); | |
JsonVal& attrs_ = n_.map[JsonVal("attrs")]; | |
for(auto& kv : n->attrs) { | |
attrs_.map[JsonVal(kv.first)] = JsonVal(kv.second); | |
} | |
} | |
return val; | |
} | |
std::string toString() { | |
JsonParser parser; | |
return parser.dump(toJson()); | |
} | |
void _dfs_util(Node* n, std::unordered_set<Node*>* to_visit, | |
std::function<void(Node*)> handler) { | |
to_visit->erase(n); | |
for(NodeEntry& e : n->outputs) { | |
Node* o = e.node; | |
if(to_visit->count(o) != 0) { | |
_dfs_util(o,to_visit,handler); | |
} | |
} | |
handler(n); | |
} | |
void DFS(std::function<void(Node*)> handler) { | |
std::unordered_set<Node*> to_visit; | |
//put all nodes in set to visit | |
for(auto& n : nodes) | |
to_visit.insert(n); | |
//visit all inputs first | |
for(auto& i : inputs) | |
if(to_visit.count(i) != 0) | |
_dfs_util(i, &to_visit, handler); | |
//visit any nodes left | |
while(to_visit.size() > 0) | |
_dfs_util(*(to_visit.begin()), &to_visit, handler); | |
} | |
std::vector<Node*> topological_sort() { | |
std::vector<Node*> sorted; | |
auto handler = [&](Node* n) { | |
sorted.push_back(n); | |
}; | |
DFS(handler); | |
return sorted; | |
} | |
void print() { | |
std::cout << "########### Graph #############" << std::endl; | |
std::cout << "inputs: " << inputs.size() << std::endl; | |
std::cout << "outputs: " << outputs.size() << std::endl; | |
std::cout << "nodes: " << nodes.size() << std::endl; | |
std::vector<Node*> sorted; | |
auto handler = [&](Node* n) { | |
sorted.push_back(n); | |
}; | |
DFS(handler); | |
for(int i=sorted.size()-1; i>=0; i--) { | |
std::cout << "Node: " << sorted[i]->name << std::endl; | |
for(int j=0; j<sorted[i]->inputs.size(); j++) { | |
std::cout << "\tInput: " << sorted[i]->inputs[j].node->name << " " << sorted[i]->inputs[j].entry << std::endl; | |
} | |
for(int j=0; j<sorted[i]->outputs.size(); j++) { | |
std::cout << "\tOutput: " << sorted[i]->outputs[j].node->name << " " << sorted[i]->outputs[j].entry << std::endl; | |
} | |
} | |
std::cout << "###############################" << std::endl; | |
} | |
std::vector<Node*> nodes; | |
std::vector<Node*> inputs; | |
std::vector<NodeEntry> outputs; | |
std::map<std::string, JsonVal> attrs; | |
}; | |
/* \brief a basic pass that parses the input string to JSON and then dumps it back */ | |
MXReturnValue graphPass(const std::string& in_graph, const std::string** out_graph, | |
const std::unordered_map<std::string, std::string>& options, | |
const std::unordered_map<std::string, MXTensor>& args, | |
const std::unordered_map<std::string, MXTensor>& aux, | |
const PassResource& res) { | |
//convert graph from JSON string to Graph/Node data structure | |
Graph g = Graph::fromString(in_graph); | |
//print initial graph | |
//g.print(); | |
//create a new arg param | |
MXTensor* arg_ = res.alloc_arg("test_arg",{3,2},MXContext::CPU(0),kFloat32); | |
//find node with 'elemwise_add' op type | |
Node* add = nullptr; | |
for(Node* n : g.nodes) | |
if(n->op.compare("elemwise_add") == 0) | |
add = n; | |
//create a new input Node | |
Node* n = new Node(); | |
n->name = "test_arg"; | |
n->op = "null"; | |
//add a new node in graph | |
g.nodes.push_back(n); | |
g.inputs.push_back(n); | |
//disconnect old input from add node | |
add->inputs[0].node->outputs.clear(); | |
//disconnect add node from old input and connect add node to new input | |
add->inputs[0].node = n; | |
add->inputs[0].entry = 0; | |
//connect new input to add node | |
n->outputs.push_back({add,0}); | |
//print modified graph | |
//g.print(); | |
//convert back to JSON string from Graph/Node | |
*out_graph = new std::string(g.toString()); | |
return MX_SUCCESS; | |
} | |
// register graphPass as a pass; the accompanying Python test selects it with
// optimize_for('graphPass', ...)
REGISTER_PASS(graphPass)
.setBody(graphPass);
/* \brief library entry check: accepts MXNet versions >= 10700.
 * Prints whether `version` is supported and returns MX_SUCCESS or MX_FAIL
 * accordingly. */
MXReturnValue initialize(int version) {
  const bool supported = version >= 10700;
  std::cout << "MXNet version " << version
            << (supported ? " supported" : " not supported") << std::endl;
  return supported ? MX_SUCCESS : MX_FAIL;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py | |
index 14944a5..8157bd4 100644 | |
--- a/python/mxnet/symbol/symbol.py | |
+++ b/python/mxnet/symbol/symbol.py | |
@@ -1485,17 +1485,17 @@ class Symbol(SymbolBase): | |
assert isinstance(backend, str) | |
if args is None or len(args) == 0: | |
- args = [] | |
+ args_ = [] | |
args_handle = c_array(NDArrayHandle, []) | |
else: | |
- args_handle, args = self._get_ndarray_inputs('args', args, | |
+ args_handle, args_ = self._get_ndarray_inputs('args', args, | |
self.list_arguments(), False) | |
if aux is None or len(aux) == 0: | |
- aux = [] | |
+ aux_ = [] | |
aux_handle = c_array(NDArrayHandle, []) | |
else: | |
- aux_handle, aux = self._get_ndarray_inputs('aux_states', aux, | |
+ aux_handle, aux_ = self._get_ndarray_inputs('aux_states', aux, | |
self.list_auxiliary_states(), False) | |
if ctx is None: | |
ctx = current_context() | |
@@ -1517,9 +1517,9 @@ class Symbol(SymbolBase): | |
c_str(backend), | |
ctypes.c_int(ctx.device_typeid), | |
ctypes.byref(out), | |
- mx_uint(len(args)), | |
+ mx_uint(len(args_)), | |
args_handle, | |
- mx_uint(len(aux)), | |
+ mx_uint(len(aux_)), | |
aux_handle, | |
mx_uint(len(key_list)), | |
c_str_array(key_list), |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os, ctypes | |
import mxnet as mx | |
from mxnet.gluon import nn | |
from mxnet import nd | |
from mxnet.base import _LIB, check_call, mx_uint, c_str, c_str_array, SymbolHandle | |
# load the compiled pass library from the current working directory
# (nothing is loaded on platforms other than posix / nt)
if (os.name=='posix'):
    path = os.path.abspath('libpass_lib.so')
    mx.library.load(path)
elif (os.name=='nt'):
    path = os.path.abspath('libpass_lib.dll')
    mx.library.load(path)
###############################################
# Test with not consuming params
###############################################
# example model, ops do not have args (use outputs from other ops as inputs)
# computes log(exp(a + b))
a = mx.sym.var('a')
b = mx.sym.var('b')
c = a + b
d = mx.sym.exp(c)
sym = mx.sym.log(d)
def test_graph():
    """Run `sym` once unmodified and once after the 'graphPass' pass, printing both results."""
    # execute in MXNet
    print('-------------------------------')
    print('Testing regular MXNet execution')
    exe = sym.bind(ctx=mx.cpu(), args={'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))})
    out = exe.forward()
    print(out)

    # Symbol optimize_for
    # with propagating shapes/types
    print('-------------------------------')
    print('Testing graphPass with shapes/types')
    args = {'a':mx.nd.ones((3,2)), 'b':mx.nd.ones((3,2))}
    aux = {}
    print(sym.tojson())
    # loop variable deliberately not called `a`, to avoid shadowing the symbol
    for name in args:
        print('%s: %s' % (name,args[name].shape))
    mysym2 = sym.optimize_for('graphPass',args,aux)
    print(mysym2.tojson())
    # NOTE(review): binding with the same `args` dict presumably relies on
    # optimize_for having added the pass-created arg to it - confirm
    exe2 = mysym2.bind(ctx=mx.cpu(), args=args)
    out2 = exe2.forward()
    print(out2)

test_graph()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment