Skip to content

Instantly share code, notes, and snippets.

@dboyliao
Last active July 17, 2017 07:18
Show Gist options
  • Save dboyliao/16dd630993d0a62edeb7f601ad0be253 to your computer and use it in GitHub Desktop.
Save dboyliao/16dd630993d0a62edeb7f601ad0be253 to your computer and use it in GitHub Desktop.
python
*.pb
*.pbtxt

Code for this blog post.

Run `make example` to see how it works (or `make txt-example` for the protobuf text-format example).

Note: you must install `protoc` (the Protocol Buffers compiler) before running the example code. The Python scripts also need the `protobuf` and `numpy` packages.

from python.string_fmt_example_pb2 import Box
from google.protobuf import text_format
def _merge_file_into(path, message):
    """
    Merge the message stored in *path* into *message*.

    The file is parsed as protobuf text format first; if that fails it is
    parsed as the binary wire format instead. Parsing happens into a fresh
    message of the same type and is merged into *message* only on success,
    so a failed text-format attempt cannot leave *message* partially updated.
    Errors raised by the binary fallback propagate to the caller.
    """
    # Read raw bytes: text format is decoded explicitly below, and the
    # binary fallback (ParseFromString) requires bytes, not str.
    with open(path, "rb") as rf:
        raw = rf.read()
    parsed = type(message)()
    try:
        text_format.Merge(raw.decode("utf-8"), parsed)
    except (text_format.ParseError, UnicodeDecodeError):
        # Not text format: ParseFromString clears `parsed` before parsing,
        # discarding anything the failed text parse may have added.
        parsed.ParseFromString(raw)
    message.MergeFrom(parsed)


def main():
    """
    Demonstrate protobuf text format with the ``Box`` message.

    Writes two boxes to ``box1.pbtxt`` and ``box2.pbtxt`` in text format,
    merges both files into a single ``Box`` and prints every item as
    ``id: name``.
    """
    box1 = Box()
    banana = box1.items.add()
    banana.name = "banana"
    banana.id = 1
    apple = box1.items.add()
    apple.name = "apple"
    apple.id = 2
    # write message to file in text format
    with open("box1.pbtxt", "w", encoding="utf-8") as wf:
        wf.write(text_format.MessageToString(box1))

    box2 = Box()
    knife = box2.items.add()
    knife.name = "knife"
    knife.id = 183
    with open("box2.pbtxt", "w", encoding="utf-8") as wf:
        wf.write(text_format.MessageToString(box2))

    # Read both files into the same message object; merging appends the
    # repeated `items` field, so new_box ends up holding every item.
    new_box = Box()
    for path in ("box1.pbtxt", "box2.pbtxt"):
        _merge_file_into(path, new_box)

    for item in new_box.items:
        print("{}: {}".format(item.id, item.name))


if __name__ == "__main__":
    main()
syntax = "proto3";
package fully_nn;
// Dense weight tensor stored as its shape plus a flat list of values.
// The serializer fills `shape` from the numpy array's shape and `data` from
// its flattened (row-major, numpy's default order) values; the reader
// rebuilds the array with numpy reshape.
message Weight {
  // Size of each dimension, outermost first.
  repeated int32 shape = 1;
  // Flattened values; the reader's reshape requires the count to equal the
  // product of `shape`. Stored as 32-bit float, so float64 inputs lose
  // precision.
  repeated float data = 2;
}
// One fully connected layer: its weight matrix, a label, and the name of
// its activation function.
message Layer {
  Weight weight = 1;
  // Layer label; the serializer uses "layer_<index>".
  string name = 2;
  // Activation function name as a plain string; the serializer picks
  // "sigmoid" or "relu".
  string act_fun_name = 3;
}
// A whole network: its layers in the order they were added (the serializer
// adds them from the input side, layer i mapping size n_i to n_{i+1}).
message FullyConnectedNetwork {
  repeated Layer layers = 1;
}
.PHONY: example
# Compile fully_connect_nn.proto, then round-trip a randomly initialised
# network through my_nn.pb: serialize it with write_nn.py and read it back
# with read_nn.py. The scripts are run through python3 so they work even
# without the executable bit set.
example:
	@mkdir -p python
	@touch python/__init__.py
	@echo "Compile protobuf file for python"
	protoc --python_out=python fully_connect_nn.proto
	@echo "Serializing NN" && python3 write_nn.py my_nn.pb
	@echo
	@echo "Deserializing NN" && python3 read_nn.py my_nn.pb
.PHONY: txt-example
# Compile string_fmt_example.proto and run the protobuf text-format demo.
txt-example:
	@mkdir -p python
	@touch python/__init__.py
	@echo "Compile protobuf file for text format example (python)"
	protoc --python_out=python ./string_fmt_example.proto
	@echo "Running example code: example_str_fmt.py"
	python3 example_str_fmt.py
#!/usr/bin/env python3
import argparse
from python.fully_connect_nn_pb2 import FullyConnectedNetwork
import numpy as np
def main(in_fname):
    """
    Load a serialized FullyConnectedNetwork and print every layer.

    :param in_fname: path of the binary protobuf file to read.

    Each layer's flat weight values are reshaped using the stored shape and
    printed under the layer name, followed by the activation function name
    on its own line (prefixed with a space).
    """
    network = FullyConnectedNetwork()
    print("Reading NN")
    with open(in_fname, "rb") as serialized:
        payload = serialized.read()
    network.ParseFromString(payload)
    for pb_layer in network.layers:
        matrix = np.array(pb_layer.weight.data).reshape(pb_layer.weight.shape)
        print("{}:\n{}".format(pb_layer.name, matrix))
        print(" {}".format(pb_layer.act_fun_name))
if __name__ == "__main__":
    # Command-line entry point: a single positional argument naming the file
    # to deserialize, forwarded to main() as the keyword `in_fname`.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "in_fname",
        metavar="IN_FILE",
        help="input protobuf serialized file",
    )
    main(**vars(parser.parse_args()))
syntax = "proto3";
package string_fmt_example;
// A single item identified by a name and a numeric id.
message Item {
  string name = 1;
  int32 id = 2;
}
// A container of items. Merging several Box messages into one (as the
// text-format example does) appends their `items`.
message Box {
  // Uses field number 2; number 1 is not used in this message.
  repeated Item items = 2;
}
#!/usr/bin/env python3
import argparse
from python.fully_connect_nn_pb2 import FullyConnectedNetwork
import numpy as np
import random
def main(out_fname, nn_struct):
    """
    Build a randomly initialised fully connected network and serialize it.

    :param out_fname: path of the binary protobuf file to write.
    :param nn_struct: sequence of layer sizes ``(n_0, n_1, ..., n_k)``. One
        weight matrix of shape ``(n_i, n_{i+1})`` drawn with
        ``np.random.randn`` is created for every pair of consecutive sizes,
        so the network has ``k`` layers (none if fewer than two sizes).

    Each layer's name, weights and randomly chosen activation function are
    printed as the network is built.
    """
    # One (in, out) weight matrix per pair of consecutive layer sizes.
    weights = [np.random.randn(in_shape, out_shape)
               for in_shape, out_shape in zip(nn_struct[:-1], nn_struct[1:])]

    pb_nn = FullyConnectedNetwork()
    print("Writing NN")
    act_funcs = ["sigmoid", "relu"]
    for i, weight in enumerate(weights):
        pb_layer = pb_nn.layers.add()
        # Store the shape plus the row-major flattened values so a reader can
        # rebuild the matrix with numpy's default reshape order.
        # NOTE: the proto `data` field is 32-bit float, so values lose
        # precision compared with what is printed below.
        pb_layer.weight.shape.extend(weight.shape)
        pb_layer.weight.data.extend(weight.ravel())
        pb_layer.name = "layer_{}".format(i)
        pb_layer.act_fun_name = random.choice(act_funcs)
        print("{}:\n{}".format(pb_layer.name, weight))
        print(pb_layer.act_fun_name)
    with open(out_fname, "wb") as wf:
        wf.write(pb_nn.SerializeToString())
def _struct_type(argstr):
struct = map(int, argstr.strip().split(","))
return tuple(struct)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # "--network-structure" is the correctly spelled flag; the historical
    # misspelling "--network-strucutre" is kept as an alias so existing
    # command lines keep working. Both store into `nn_struct`.
    parser.add_argument("-s", "--network-structure", "--network-strucutre",
                        dest="nn_struct",
                        metavar="INT,INT,...", type=_struct_type,
                        help="structure of the neural network (default: 3,3,3)",
                        default=(3, 3, 3))
    parser.add_argument("out_fname", metavar="OUT_FILE",
                        help="output serialized file")
    args = vars(parser.parse_args())
    main(**args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment