Skip to content

Instantly share code, notes, and snippets.

@jamesr66a
Created November 8, 2017 20:07
Show Gist options
  • Save jamesr66a/3d0c923b3cf61d0dd5f00a2b294526a9 to your computer and use it in GitHub Desktop.
Save jamesr66a/3d0c923b3cf61d0dd5f00a2b294526a9 to your computer and use it in GitHub Desktop.
Invariant hoisting algorithm
#include "onnx.pb.h"
#include <algorithm>
#include <fstream>
#include <iostream>
#include <limits>
#include <numeric>
#include <queue>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <google/protobuf/text_format.h>
namespace onnx {
struct ReaderWriterRecord {
std::unordered_set<int> readers;
std::unordered_set<int> writers;
};
using EdgeMap = std::unordered_map<std::string, ReaderWriterRecord>;
// Output: map from blob name to (reader op indices, writer op indices)
EdgeMap createEdgeMap(const GraphProto &gp) {
EdgeMap retval;
for (int i = 0; i < gp.node_size(); ++i) {
auto &node = gp.node(i);
for (const auto &in : node.input()) {
retval[in].readers.insert(i);
}
for (const auto &out : node.output()) {
retval[out].writers.insert(i);
}
}
return retval;
}
void markLiveInputNodes(const GraphProto &gp, const EdgeMap &em,
std::unordered_set<int> *live_nodes) {
std::unordered_set<std::string> initializers;
for (const auto& i : gp.initializer()) {
initializers.insert(i.name());
}
std::unordered_set<std::string> real_inputs;
for (const auto& i : gp.input()) {
if (initializers.find(i.name()) == initializers.end()) {
real_inputs.insert(i.name());
}
}
for (const auto &in : real_inputs) {
// BFS from each input
std::queue<int> frontier;
std::unordered_set<int> seen;
// TODO: assert has name
for (const int adj : em.at(in).readers) {
frontier.push(adj);
live_nodes->insert(adj);
}
while (!frontier.empty()) {
int next = frontier.front();
frontier.pop();
if (seen.find(next) != seen.end()) {
continue;
}
seen.insert(next);
for (const auto &out : gp.node(next).output()) {
for (const int adj : em.at(out).readers) {
frontier.push(adj);
live_nodes->insert(adj);
}
}
}
}
}
void markDeadCode(const GraphProto &gp, const EdgeMap &em,
std::unordered_set<int> *dead_nodes) {
}
std::tuple<GraphProto, GraphProto> partitionImmutableOps(const GraphProto &gp) {
std::cout << "creating edge map\n";
EdgeMap em = createEdgeMap(gp);
std::unordered_set<int> live_nodes;
std::cout << "mark live input nodes\n";
markLiveInputNodes(gp, em, &live_nodes);
for (const auto& x : live_nodes) {
std::cout << x << ", ";
}
std::cout << std::endl;
#if 0
std::vector<int> indices(gp.node_size());
std::iota(indices.begin(), indices.end(), 0);
std::unordered_set<int> indices_copy(indices.begin(), indices.end());
std::vector<int> a;
std::set_difference(indices_copy.begin(), indices_copy.end(), live_nodes.begin(), live_nodes.end(), std::inserter(a, a.begin()));
#endif
std::vector<int> a;
for (int i=0; i<gp.node_size(); ++i) {
if (live_nodes.find(i) == live_nodes.end()) {
a.push_back(i);
}
}
std::unordered_set<int> imm(a.begin(), a.end());
GraphProto new_proto;
new_proto.CopyFrom(gp);
new_proto.clear_node();
for (int i=0; i<gp.node_size(); ++i) {
if (imm.find(i) != imm.end()) {
continue;
}
*new_proto.add_node() = gp.node(i);
}
std::cout << "Immutable ops" << std::endl;
for (const auto& x : a) {
std::cout << x << ", ";
}
std::cout << std::endl;
for (const auto& i : live_nodes) {
auto& node = gp.node(i);
std::string out;
google::protobuf::TextFormat::PrintToString(node, &out);
std::cout << out << std::endl;
}
return std::tuple<GraphProto, GraphProto>(GraphProto(), new_proto);
}
} // namespace onnx
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
// Parses `length` bytes starting at `buffer` into *proto.
//
// A CodedInputStream is used so the total-bytes limit can be raised above the
// protobuf default, which large model files exceed.
//
// @return false if the bytes do not form a valid message, or if `length` does
//         not fit in the int size that ArrayInputStream accepts.
template <typename Proto>
bool ParseProtoFromBytes(Proto* proto, const char* buffer, size_t length) {
  if (length > static_cast<size_t>(std::numeric_limits<int>::max())) {
    return false;
  }
  // CodedInputStream does not take ownership of the stream it wraps, so the
  // ArrayInputStream lives on the stack instead of being leaked via `new`.
  ::google::protobuf::io::ArrayInputStream array_stream(
      buffer, static_cast<int>(length));
  ::google::protobuf::io::CodedInputStream coded_stream(&array_stream);
  // Total bytes hard limit / warning limit are set to 1GB and 512MB
  // respectively.
  coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
  return proto->ParseFromCodedStream(&coded_stream);
}
// Usage: <binary> <model.pb>
//
// Loads an ONNX ModelProto and prints its graph input names. It then
// partitions the graph's nodes and writes a copy of the model containing only
// the predict (input-dependent) nodes to "culled_model.pb".
//
// Returns -1 if:
//   - no model file argument is given,
//   - the file cannot be read,
//   - the model cannot be parsed, or
//   - serialization fails.
int main(int argc, char **argv) {
  std::string infile;
  if (argc >= 2) {
    infile = argv[1];
  } else {
    std::cerr << "please specify model proto file" << std::endl;
    return -1;
  }

  // Read exactly the file's bytes. Parsing a fixed-size oversized buffer
  // would hand the parser uninitialized memory past end-of-file.
  std::ifstream input(infile, std::ios::in | std::ios::binary);
  if (!input) {
    std::cerr << "failed to open " << infile << std::endl;
    return -1;
  }
  input.seekg(0, std::ios::end);
  const std::streamoff file_size = input.tellg();
  if (file_size < 0) {
    std::cerr << "failed to determine size of " << infile << std::endl;
    return -1;
  }
  input.seekg(0, std::ios::beg);
  std::string bytes(static_cast<size_t>(file_size), '\0');
  if (!input.read(&bytes[0], file_size)) {
    std::cerr << "failed to read " << infile << std::endl;
    return -1;
  }

  onnx::ModelProto model;
  if (!ParseProtoFromBytes<onnx::ModelProto>(&model, bytes.data(),
                                             bytes.size())) {
    std::cerr << "failed to parse model proto" << std::endl;
    return -1;
  }
  for (const auto& x : model.graph().input()) {
    std::cout << x.name() << std::endl;
  }

  onnx::GraphProto init, predict;
  std::tie(init, predict) = onnx::partitionImmutableOps(model.graph());

  onnx::ModelProto nmp;
  nmp.CopyFrom(model);
  *nmp.mutable_graph() = predict;
  std::fstream out("culled_model.pb", std::ios::out | std::ios::trunc | std::ios::binary);
  if (!nmp.SerializeToOstream(&out)) {
    std::cerr << "Failed to serialize\n";
    return -1;
  }
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment