Last active
December 18, 2018 06:02
-
-
Save qfgaohao/597cb3ccc6ebd1b544a0a14f90fce1dc to your computer and use it in GitHub Desktop.
Demonstrates how to train a model, initialize weights from another source (transfer learning), and save models to .pb and .pbtxt files.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from caffe2.python import ( | |
brew, | |
model_helper, | |
optimizer, | |
workspace, | |
utils, | |
) | |
from caffe2.proto import caffe2_pb2 | |
from caffe2.python.predictor import mobile_exporter | |
def gen_data(batch_size=2):
    """Draw one random batch for the toy linear-regression problem.

    Targets are a noise-free linear function of the inputs:
    y = 1.0 * x[:, 0] + 2.0 * x[:, 1] + 0.5.

    Args:
        batch_size: number of samples (rows) to draw.

    Returns:
        A ``(features, targets)`` tuple: ``features`` has shape
        ``(batch_size, 2)`` and ``targets`` has shape ``(batch_size, 1)``.
    """
    true_weights = np.array([[1.0], [2.0]])
    true_bias = 0.5
    features = np.random.randn(batch_size, 2)
    targets = np.dot(features, true_weights) + true_bias
    return features, targets
# Show one freshly generated batch so the data format is visible up front.
x, y = gen_data()
print("-------------Training data sample--------------")
for label, values in (("x", x), ('y', y)):
    print(label, values)
print('\n\n')
def create_net(model):
    """Add the regression network (one fully connected layer) to ``model``.

    The layer reads its input from blob 'X' and writes blob 'y_pred'
    (2 inputs -> 1 output). Its parameters are the blobs this script
    later feeds directly as 'y_pred_w' / 'y_pred_b'. Because those names
    are fixed, a training model and a test model built with this function
    share the same weights through the workspace.

    Args:
        model: the ModelHelper to add the layer to. Previously this argument
            was ignored and the global ``train_model`` was always used, so
            building a test model silently added a second FC op to the
            training net instead.

    Returns:
        The output blob reference ('y_pred') returned by ``brew.fc``.
    """
    return brew.fc(model, 'X', 'y_pred', dim_in=2, dim_out=1)
# Start from a clean workspace and build the training graph:
#   X -> FC -> y_pred;  squared L2 distance to Y_gt;  averaged into "loss".
workspace.ResetWorkspace()
train_model = model_helper.ModelHelper(name='regression model')
y_pred = create_net(train_model)
dist = train_model.SquaredL2Distance(["Y_gt", y_pred], "dist")
loss = train_model.AveragedLoss(dist, "loss")
# Append the backward pass for the loss, then plain SGD parameter updates.
train_model.AddGradientOperators([loss])
optimizer.build_sgd(train_model, base_learning_rate=0.01)
# Prime the workspace with one batch so the net's input blobs (Y_gt, X)
# exist before the net is created.
x, y = gen_data()
for blob, value in (("Y_gt", y), ("X", x)):
    workspace.FeedBlob(blob, value.astype(np.float32))
# Run the parameter-initialisation net, then instantiate the training net.
workspace.RunNetOnce(train_model.param_init_net)
workspace.CreateNet(train_model.net)
# Parameters live in the workspace as ordinary blobs, so they can be
# overwritten directly, e.g. with weights from another source, instead of
# (or before) training.
print("you can just inject weights from somewhere without training.")
injected_w = np.random.randn(1, 2).astype(np.float32)
injected_b = np.zeros(1, dtype=np.float32)
workspace.FeedBlob("y_pred_w", injected_w)
workspace.FeedBlob("y_pred_b", injected_b)

# Train with SGD: 500 iterations, each on a freshly generated batch.
for _step in range(500):
    x, y = gen_data()
    workspace.FeedBlob('Y_gt', y.astype(np.float32))
    workspace.FeedBlob('X', x.astype(np.float32))
    workspace.RunNet(train_model.net)
# Create an inference-only net. init_params=False: the trained parameters
# already in the workspace are reused rather than re-initialised.
test_model = model_helper.ModelHelper(name="test_net", init_params=False)
create_net(test_model)

# The net built by create_net reads its input from blob 'X'. Feeding a blob
# named 'data' (as before) left 'X' holding the last training batch, so the
# "test" silently ran on training data. Prime 'X' before CreateNet, mirroring
# how the training net was set up.
data = np.zeros((1, 2)).astype('float32')
workspace.FeedBlob("X", data)
workspace.RunNetOnce(test_model.param_init_net)
workspace.CreateNet(test_model.net, overwrite=True)
workspace.RunNet(test_model.net, 1)

# test, optional: predict on a random batch of 5 samples.
workspace.FeedBlob('X', np.random.randn(5, 2).astype(np.float32))
workspace.RunNet(test_model.net, 1)
print("Testing results:\n")
print(workspace.FetchBlob('y_pred'))
# Export the test net together with its parameters (read from the current
# workspace) as an (init_net, predict_net) pair, then write each one both as
# binary protobuf (.pb) and as human-readable text (.pbtxt).
print("Save the model to init_net.pb and predict_net.pb")
init_net, predict_net = mobile_exporter.Export(workspace, test_model.net, test_model.params)
exported = (("init_net", init_net), ("predict_net", predict_net))
for stem, proto in exported:
    with open(stem + ".pb", 'wb') as f:
        f.write(proto.SerializeToString())
# The message previously read "Save the mode to ..." (typo).
print("Save the model to init_net.pbtxt and predict_net.pbtxt")
for stem, proto in exported:
    with open(stem + ".pbtxt", 'w') as f:
        f.write(str(proto))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi qfgaohao,
Which changes would I have to make in order to load the model from .pbtxt files? I use the following method to load the net: