// Invariant hoisting algorithm
//
// Source: GitHub gist jamesr66a/3d0c923b3cf61d0dd5f00a2b294526a9
// Created: November 8, 2017 20:07
//
// NOTE: GitHub flagged this file as containing hidden or bidirectional Unicode
// text that may be interpreted or compiled differently than it appears. Review
// it in an editor that reveals hidden Unicode characters.
#include "onnx.pb.h"

#include <algorithm>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <numeric>
#include <queue>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <google/protobuf/text_format.h>
namespace onnx { | |
// Per-blob bookkeeping: which ops (by index into the graph's node list)
// consume the blob and which ops produce it.
struct ReaderWriterRecord {
  // Indices of nodes that list the blob among their inputs.
  std::unordered_set<int> readers;
  // Indices of nodes that list the blob among their outputs.
  std::unordered_set<int> writers;
};

// Maps a blob name to the ops that read and write it.
using EdgeMap = std::unordered_map<std::string, ReaderWriterRecord>;
// Builds the blob-name -> (reader op indices, writer op indices) map for gp.
// A node's index is recorded as a reader of every blob it lists as an input
// and as a writer of every blob it lists as an output.
EdgeMap createEdgeMap(const GraphProto &gp) {
  EdgeMap edges;
  const int node_count = gp.node_size();
  for (int idx = 0; idx != node_count; ++idx) {
    const auto &op = gp.node(idx);
    for (const auto &blob : op.input()) {
      edges[blob].readers.insert(idx);
    }
    for (const auto &blob : op.output()) {
      edges[blob].writers.insert(idx);
    }
  }
  return edges;
}
void markLiveInputNodes(const GraphProto &gp, const EdgeMap &em, | |
std::unordered_set<int> *live_nodes) { | |
std::unordered_set<std::string> initializers; | |
for (const auto& i : gp.initializer()) { | |
initializers.insert(i.name()); | |
} | |
std::unordered_set<std::string> real_inputs; | |
for (const auto& i : gp.input()) { | |
if (initializers.find(i.name()) == initializers.end()) { | |
real_inputs.insert(i.name()); | |
} | |
} | |
for (const auto &in : real_inputs) { | |
// BFS from each input | |
std::queue<int> frontier; | |
std::unordered_set<int> seen; | |
// TODO: assert has name | |
for (const int adj : em.at(in).readers) { | |
frontier.push(adj); | |
live_nodes->insert(adj); | |
} | |
while (!frontier.empty()) { | |
int next = frontier.front(); | |
frontier.pop(); | |
if (seen.find(next) != seen.end()) { | |
continue; | |
} | |
seen.insert(next); | |
for (const auto &out : gp.node(next).output()) { | |
for (const int adj : em.at(out).readers) { | |
frontier.push(adj); | |
live_nodes->insert(adj); | |
} | |
} | |
} | |
} | |
} | |
// Placeholder for a dead-code marking pass. The body is empty: none of the
// arguments are read and dead_nodes is never written.
void markDeadCode(const GraphProto &gp, const EdgeMap &em,
                  std::unordered_set<int> *dead_nodes) {
  // Not implemented; the casts only mark the parameters as deliberately unused.
  (void)gp;
  (void)em;
  (void)dead_nodes;
}
std::tuple<GraphProto, GraphProto> partitionImmutableOps(const GraphProto &gp) { | |
std::cout << "creating edge map\n"; | |
EdgeMap em = createEdgeMap(gp); | |
std::unordered_set<int> live_nodes; | |
std::cout << "mark live input nodes\n"; | |
markLiveInputNodes(gp, em, &live_nodes); | |
for (const auto& x : live_nodes) { | |
std::cout << x << ", "; | |
} | |
std::cout << std::endl; | |
#if 0 | |
std::vector<int> indices(gp.node_size()); | |
std::iota(indices.begin(), indices.end(), 0); | |
std::unordered_set<int> indices_copy(indices.begin(), indices.end()); | |
std::vector<int> a; | |
std::set_difference(indices_copy.begin(), indices_copy.end(), live_nodes.begin(), live_nodes.end(), std::inserter(a, a.begin())); | |
#endif | |
std::vector<int> a; | |
for (int i=0; i<gp.node_size(); ++i) { | |
if (live_nodes.find(i) == live_nodes.end()) { | |
a.push_back(i); | |
} | |
} | |
std::unordered_set<int> imm(a.begin(), a.end()); | |
GraphProto new_proto; | |
new_proto.CopyFrom(gp); | |
new_proto.clear_node(); | |
for (int i=0; i<gp.node_size(); ++i) { | |
if (imm.find(i) != imm.end()) { | |
continue; | |
} | |
*new_proto.add_node() = gp.node(i); | |
} | |
std::cout << "Immutable ops" << std::endl; | |
for (const auto& x : a) { | |
std::cout << x << ", "; | |
} | |
std::cout << std::endl; | |
for (const auto& i : live_nodes) { | |
auto& node = gp.node(i); | |
std::string out; | |
google::protobuf::TextFormat::PrintToString(node, &out); | |
std::cout << out << std::endl; | |
} | |
return std::tuple<GraphProto, GraphProto>(GraphProto(), new_proto); | |
} | |
} // namespace onnx | |
#include <google/protobuf/io/coded_stream.h> | |
#include <google/protobuf/io/zero_copy_stream_impl_lite.h> | |
template <typename Proto> | |
bool ParseProtoFromBytes(Proto* proto, const char* buffer, size_t length) { | |
// Total bytes hard limit / warning limit are set to 1GB and 512MB | |
// respectively. | |
::google::protobuf::io::CodedInputStream coded_stream( | |
new google::protobuf::io::ArrayInputStream(buffer, length)); | |
coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20); | |
return proto->ParseFromCodedStream(&coded_stream); | |
} | |
int main(int argc, char **argv) { | |
std::string infile; | |
if (argc >= 2) { | |
infile = argv[1]; | |
} else { | |
std::cerr << "please specify model proto file" << std::endl; | |
return -1; | |
} | |
char *stream = (char*)malloc(1 * 1000 * 1000 * 1000); | |
std::fstream input(infile, std::ios::in | std::ios::binary); | |
input.read(stream, 1*1000*1000*1000); | |
onnx::ModelProto model; | |
ParseProtoFromBytes<onnx::ModelProto>(&model, stream, 1*1000*1000*1000); | |
for (const auto& x : model.graph().input()) { | |
std::cout << x.name() << std::endl; | |
} | |
onnx::GraphProto init, predict; | |
std::tie(init, predict) = partitionImmutableOps(model.graph()); | |
onnx::ModelProto nmp; | |
nmp.CopyFrom(model); | |
*nmp.mutable_graph() = predict; | |
std::fstream out("culled_model.pb", std::ios::out | std::ios::trunc | std::ios::binary); | |
if (!nmp.SerializeToOstream(&out)) { | |
std::cerr << "Failed to serialize\n"; | |
return -1; | |
} | |
free(stream); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment