Code for this blog post
Run `make example`
to see how it works.
Note: you have to install protoc
(the Protocol Buffers compiler) before running the example code.
python | |
*.pb | |
*.pbtxt |
from python.string_fmt_example_pb2 import Box | |
from google.protobuf import text_format | |
def _write_text_message(path, message):
    """Write *message* to *path* in protobuf text format."""
    with open(path, "w") as wf:
        wf.write(text_format.MessageToString(message))


def _merge_message_file(path, message):
    """Merge the message stored in the file at *path* into *message*.

    The file content is first interpreted as protobuf text format. If it
    cannot be decoded or parsed as text, it is treated as binary wire
    format instead. In both cases the content is merged, so fields already
    present in *message* are kept.
    """
    # Read raw bytes so the same payload can feed either parser.
    with open(path, "rb") as rf:
        payload = rf.read()

    # Parse the text form into a scratch message first, so that a failed
    # text parse does not leave half-merged fields in *message*.
    scratch = type(message)()
    try:
        text_format.Merge(payload.decode("utf-8"), scratch)
    except (UnicodeDecodeError, text_format.ParseError):
        # Binary fallback: MergeFromString takes bytes and, unlike
        # ParseFromString, does not clear what was merged earlier.
        message.MergeFromString(payload)
    else:
        message.MergeFrom(scratch)


def main():
    """Write two Box messages as text files, merge them back, print items."""
    box1 = Box()
    banana = box1.items.add()
    banana.name = "banana"
    banana.id = 1
    apple = box1.items.add()
    apple.name = "apple"
    apple.id = 2
    # write message to file in text format
    _write_text_message("box1.pbtxt", box1)

    box2 = Box()
    knife = box2.items.add()
    knife.name = "knife"
    knife.id = 183
    _write_text_message("box2.pbtxt", box2)

    # read both files back into one message object; the second read is
    # merged on top of the first
    new_box = Box()
    _merge_message_file("box1.pbtxt", new_box)
    _merge_message_file("box2.pbtxt", new_box)

    # All items are in new box now
    for item in new_box.items:
        print("{}: {}".format(item.id, item.name))
# Run the text-format demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
syntax = "proto3"; | |
package fully_nn; | |
// Weight values of one layer, stored as a flat list together with its shape.
message Weight {
  // Size of each dimension; readers use it to reshape `data`.
  repeated int32 shape = 1;
  // Elements of the weight array as a single flat list.
  repeated float data = 2;
}
// One layer of the network: its weights, a name and an activation name.
message Layer {
  // Weight values of this layer.
  Weight weight = 1;
  // Layer name (the writer script uses "layer_<index>").
  string name = 2;
  // Name of the activation function, stored as free-form text.
  string act_fun_name = 3;
}
// A network described as a list of layers.
message FullyConnectedNetwork {
  // Layers in the order they were added by the writer.
  repeated Layer layers = 1;
}
# Generate the Python bindings, then serialize a network and read it back.
# Declared phony so a file named "example" cannot mask the target.
.PHONY: example
example:
	@mkdir -p python
	@touch python/__init__.py
	@echo "Compile protobuf file for python"
	protoc --python_out=python fully_connect_nn.proto
	@# Run through python3 so the scripts need not be marked executable.
	@echo "Serializing NN" && python3 write_nn.py my_nn.pb
	@echo
	@echo "Deserializing NN" && python3 read_nn.py my_nn.pb
# Generate the Python bindings and run the text-format example script.
# Declared phony so a file named "txt-example" cannot mask the target.
.PHONY: txt-example
txt-example:
	@mkdir -p python
	@touch python/__init__.py
	@echo "Compile protobuf file for text format example (python)"
	protoc --python_out=python ./string_fmt_example.proto
	@echo "Running example code: example_str_fmt.py"
	python3 example_str_fmt.py
#!/usr/bin/env python3 | |
import argparse | |
from python.fully_connect_nn_pb2 import FullyConnectedNetwork | |
import numpy as np | |
def main(in_fname):
    """
    Load a serialized network from ``in_fname`` and print each layer:
    its name, its weights reshaped to the stored shape, and the name of
    its activation function.
    """
    network = FullyConnectedNetwork()
    print("Reading NN")
    # Binary mode: the file holds protobuf wire-format bytes.
    with open(in_fname, "rb") as stream:
        serialized = stream.read()
    network.ParseFromString(serialized)

    for pb_layer in network.layers:
        pb_weight = pb_layer.weight
        # The values are stored flat; the shape field restores the layout.
        matrix = np.array(pb_weight.data).reshape(pb_weight.shape)
        print("{}:\n{}".format(pb_layer.name, matrix))
        print(" {}".format(pb_layer.act_fun_name))
# Command-line entry point: one positional argument naming the input file.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "in_fname",
        metavar="IN_FILE",
        help="input protobuf serialized file",
    )
    args = parser.parse_args()
    main(in_fname=args.in_fname)
syntax = "proto3"; | |
package string_fmt_example; | |
// A single named item with a numeric id.
message Item {
  // Item name.
  string name = 1;
  // Numeric id of the item.
  int32 id = 2;
}
// A container holding any number of items.
message Box {
  // Items in the box. Uses field number 2; number 1 is unused here.
  repeated Item items = 2;
}
#!/usr/bin/env python3 | |
import argparse | |
from python.fully_connect_nn_pb2 import FullyConnectedNetwork | |
import numpy as np | |
import random | |
def main(out_fname, nn_struct):
    """
    Build a randomly initialised network and serialize it to a file.

    Args:
        out_fname: path of the output file; it is written in binary mode.
        nn_struct: sequence of layer sizes. Each pair of consecutive sizes
            produces one weight matrix of shape (in_size, out_size).
    """
    # One random weight matrix per pair of consecutive layer sizes.
    weights = [
        np.random.randn(in_shape, out_shape)
        for in_shape, out_shape in zip(nn_struct[:-1], nn_struct[1:])
    ]

    pb_nn = FullyConnectedNetwork()
    print("Writing NN")
    act_funcs = ["sigmoid", "relu"]
    for i, weight in enumerate(weights):
        pb_layer = pb_nn.layers.add()
        # Store the shape next to the flattened values so that a reader
        # can restore the original layout.
        pb_layer.weight.shape.extend(weight.shape)
        pb_layer.weight.data.extend(weight.flatten())
        pb_layer.name = "layer_{}".format(i)
        # Pick one activation name at random for this layer.
        pb_layer.act_fun_name = random.choice(act_funcs)
        print("{}:\n{}".format(pb_layer.name, weight))
        print("{}".format(pb_layer.act_fun_name))

    with open(out_fname, "wb") as wf:
        wf.write(pb_nn.SerializeToString())
def _struct_type(argstr): | |
struct = map(int, argstr.strip().split(",")) | |
return tuple(struct) | |
# Command-line entry point: an optional network structure and the output path.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # NOTE: the long option is spelled "--network-strucutre" (sic); the
    # spelling is kept unchanged because it is part of the command line.
    parser.add_argument(
        "-s",
        "--network-strucutre",
        dest="nn_struct",
        metavar="INT,INT,...",
        type=_struct_type,
        default=(3, 3, 3),
        help="structure of the neural network (default: 3,3,3)",
    )
    parser.add_argument(
        "out_fname",
        metavar="OUT_FILE",
        help="output serialized file",
    )
    args = parser.parse_args()
    main(out_fname=args.out_fname, nn_struct=args.nn_struct)