Last active: March 6, 2017 04:51
-
-
Save spicavigo/3afde12b5b48b0de8e7294f1d58934c6 to your computer and use it in GitHub Desktop.
Run the ROS node using a CNN model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import rospy | |
from steering_node import SteeringNode | |
import argparse | |
import json | |
from scipy import misc | |
from keras.optimizers import SGD | |
from keras.models import model_from_json | |
def process(model, img, debug_path='test.png', crop_top=320, size=(50, 200, 3)):
    """Predict a steering value for a single camera frame.

    The frame is optionally saved to disk, cropped to the rows from
    ``crop_top`` downward, resized to ``size`` with ``scipy.misc.imresize``
    and passed to ``model.predict`` as a batch of one.

    Args:
        model: object exposing a Keras-style ``predict(batch)`` method.
        img: 3-D image array (indexed ``img[rows, cols, channels]`` here);
            presumably an RGB camera frame -- confirm encoding with
            steering_node.
        debug_path: file the raw frame is written to on every call.
            Defaults to ``'test.png'`` (the original behaviour); pass
            ``None`` to skip the write. NOTE(review): a PNG write per frame
            is expensive inside a real-time callback -- consider ``None``.
        crop_top: index of the first row kept; rows above it are dropped.
        size: target shape handed to ``misc.imresize``.

    Returns:
        Element ``[0][0]`` of the prediction, i.e. the first output of the
        only sample in the batch.

    NOTE(review): ``scipy.misc.imsave`` and ``scipy.misc.imresize`` were
    removed in SciPy 1.2 / 1.3, so this requires an older SciPy.
    """
    if debug_path is not None:
        misc.imsave(debug_path, img)
    img = misc.imresize(img[crop_top:, :, :], size)
    # img[None, ...] adds a leading batch axis of length 1 for predict().
    steering = model.predict(img[None, :, :, :])[0][0]
    # Function-call form prints identically on Python 2 and is also valid
    # Python 3 (the bare ``print steering`` statement is a SyntaxError there).
    print(steering)
    return steering
def get_model(model_file):
    """Load, compile and return the steering network described by *model_file*.

    Each call re-reads the JSON definition, builds the model with
    ``model_from_json``, compiles it with Nesterov SGD and an MSE loss, and
    loads weights from a sibling file whose name is the definition's
    filename with ``'json'`` replaced by ``'keras'`` (e.g. ``model.json`` ->
    ``model.keras``).

    Args:
        model_file: path to the model-definition JSON file.

    Returns:
        The compiled Keras model with weights loaded.
    """
    import os.path  # local import keeps this block self-contained

    with open(model_file, 'r') as jfile:
        # json.load decodes the file and its result goes straight to
        # model_from_json, which takes a JSON string. NOTE(review): this
        # assumes the file holds a JSON-encoded string (e.g. written with
        # json.dump(model.to_json(), f)) -- confirm against the trainer.
        model = model_from_json(json.load(jfile))
    sgd = SGD(lr=0.0001, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(sgd, "mse")
    # Only rename the filename component: a plain str.replace on the whole
    # path would also rewrite any directory whose name contains "json",
    # pointing at a directory that is not "the same path" as the definition.
    directory, filename = os.path.split(model_file)
    weights_file = os.path.join(directory, filename.replace('json', 'keras'))
    model.load_weights(weights_file)
    return model
if __name__ == '__main__':
    # Command-line interface: one positional argument, the model JSON path.
    arg_parser = argparse.ArgumentParser(description='Model Runner')
    arg_parser.add_argument(
        'model',
        type=str,
        help='Path to model definition json. '
             'Model weights should be on the same path.',
    )
    cli_args = arg_parser.parse_args()

    def load_model():
        """Zero-argument factory that builds the model named on the CLI."""
        return get_model(cli_args.model)

    # NOTE(review): SteeringNode receives a factory rather than a model, so
    # how often get_model (and thus model.compile) runs is decided inside
    # steering_node, which is not visible here -- confirm it is called once.
    node = SteeringNode(load_model, process)
    # Block here; `node` stays bound (and referenced) while rospy.spin() runs.
    rospy.spin()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Hello,
Will `model.compile()` get called on every spin? I am currently integrating ROS and Keras, so I wonder whether this will affect performance.