Created
March 7, 2021 05:43
-
-
Save guru-florida/372d43ae336d34eb601bbfee68678e8d to your computer and use it in GitHub Desktop.
Model Training URDF parameters in Ros2 with Gazebo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# | |
# model-training.py | |
# Use HyperOpt and Gazebo to train parameters of your robot. Performs a number | |
# of episodic simulations, optimizing the inputs to a xacro-enabled urdf file. | |
# | |
# Copyright 2020 FlyingEinstein.com | |
# Author: Colin F. MacKenzie | |
# | |
# Features: | |
# * will call your launch file via the Ros2 launch API. | |
# * will attempt to restart simulation by deleting and respawning robot without | |
# restarting Gazebo or launch file. | |
# * Sometimes ROS nodes crash or Gazebo freezes on sim restart. In this case, | |
# the whole launch file setup will be killed and restarted. | |
# * Can write episode results to log file training.csv | |
# * Can write current episode values to current.txt. If you record the training | |
# using OBS screen recorder then you can add the text overlay with this file | |
# and OBS will update the on-screen display as episodes play out. | |
# | |
# Requirements: | |
# * This was used for my project and I haven't gotten around to generalizing | |
# this code yet. So expect you may need to get intimate with how this code | |
# works. | |
# * review imu_callback and odom_callback which establishes a health value, | |
# or write a node to emit a "health value" and just subscribe to that. | |
# * Review episode_async where it sets up the episode and determines when it's | |
# finished. | |
# * determine what variables in your URDF you will optimize/learn. This can be | |
# any value in a URDF including Gazebo parameters. | |
# * Convert your urdf file to xacro if you haven't already. Replace the values | |
# you want to optimize with xacro variables. | |
# * Setup the config variable in the run() method. You will want to look at the | |
# Hyperopt library to see how these configs work. Each episode hyperopt will | |
# choose new values for your variables based on previous health scores and | |
# the episode_async will reparse your xacro urdf with the new set of variables | |
# and respawn the sim. | |
# | |
# Example arguments for training URDF parameters for the LSS humanoid model: | |
# -package lss_humanoid -xacro urdf/lss_humanoid.xacro.urdf -entity humanoid -z 0.3 -episodes 100 | |
# | |
# Portions of this file were based on spawn_entity.py from | |
# Open Source Robotics Foundation and John Hsu, Dave Coleman | |
import argparse | |
import math | |
import os | |
import sys | |
import time | |
import asyncio | |
import threading | |
import psutil | |
from typing import List | |
from typing import Text | |
from typing import Tuple | |
from collections import OrderedDict | |
# Ros2 node imports | |
import rclpy | |
from launch.event_handlers import OnProcessIO | |
from rclpy.node import Node | |
from rclpy.qos import QoSDurabilityPolicy | |
from sensor_msgs.msg import Imu | |
from geometry_msgs.msg import Vector3 | |
from geometry_msgs.msg import Pose | |
from nav_msgs.msg import Odometry | |
from gazebo_msgs.srv import SpawnEntity, DeleteEntity | |
# Ros2 launch API | |
import launch | |
import xacro | |
from ros2launch.api import get_share_file_path_from_package | |
from ament_index_python.packages import PackageNotFoundError, get_package_share_directory | |
# Hyperopt Optimization API | |
from hyperopt import hp, fmin, tpe, STATUS_OK, STATUS_FAIL, Trials | |
from hyperopt.mongoexp import MongoTrials | |
class EntityException(RuntimeError):
    """Base class for errors raised while calling a Gazebo entity service."""
class EntityOperationFailed(EntityException):
    """Signals that an entity service answered but reported failure."""
class EntityTimeout(EntityException):
    """Signals that an entity service was unavailable or did not answer in time."""
class TrainModelNode(Node): | |
simulation_task = None | |
event_loop = None | |
launch_service = None | |
optimize_thread = None | |
acc = None | |
acc_mix = 0.01 | |
attempts = 0 | |
active = False | |
fallen = False | |
distance = 0 | |
direction = 0 | |
avgAngularVelocity = None | |
startTs = None | |
currentTs = None | |
package_dir = None | |
xacro_urdf = None | |
tasks = {} | |
def __init__(self, args):
    """Create the training node from command-line arguments.

    Parses ``args`` (``args[0]`` is skipped as the program name), resolves
    the xacro URDF inside the share directory of the ``-package`` package,
    opens it, and subscribes to the ``imu/data`` and ``odom`` topics.

    Exits the process with status -2 when the package share directory
    cannot be found or the xacro file cannot be opened.

    :param args: argument list with ROS-specific arguments already removed.
    """
    super().__init__('train_model')
    parser = argparse.ArgumentParser(
        description='Spawn an entity in gazebo. Gazebo must be started with '
                    'gazebo_ros_init, gazebo_ros_factory and gazebo_ros_state '
                    'for all functionalities to work')
    parser.add_argument('-package', required=True, type=str, metavar='PKG_NAME',
                        help='The package containing the model we will train')
    parser.add_argument('-xacro', required=True, type=str, metavar='FILE_NAME',
                        help='The xacro file to substitute with parameters')
    parser.add_argument('-entity', required=True, type=str, metavar='ENTITY_NAME',
                        help='Name of entity to spawn')
    parser.add_argument('-reference_frame', type=str, default='',
                        help='Name of the model/body where initial pose is defined. '
                             'If left empty or specified as "world", gazebo world '
                             'frame is used')
    parser.add_argument('-gazebo_namespace', type=str, default='',
                        help='ROS namespace of gazebo offered ROS interfaces. '
                             'Default is without any namespace')
    parser.add_argument('-robot_namespace', type=str, default='',
                        help='change ROS namespace of gazebo-plugins')
    parser.add_argument('-timeout', type=float, default=30.0,
                        help='Number of seconds to wait for the spawn and delete '
                             'services to become available')
    parser.add_argument('-wait', type=str, metavar='ENTITY_NAME',
                        help='Wait for entity to exist')
    parser.add_argument('-spawn_service_timeout', type=float, metavar='TIMEOUT',
                        default=15.0, help='Spawn service wait timeout in seconds')
    parser.add_argument('-episodes', type=int, default=100,
                        help='max number of training iterations')
    parser.add_argument('-mongodb', type=str,
                        help='store trials in a Mongo database')
    parser.add_argument('-expid', type=int,
                        help='optional experiment ID (use with mongo and hyperopt)')
    parser.add_argument('-x', type=float, default=0,
                        help='x component of initial position, meters')
    parser.add_argument('-y', type=float, default=0,
                        help='y component of initial position, meters')
    parser.add_argument('-z', type=float, default=0,
                        help='z component of initial position, meters')
    parser.add_argument('-R', type=float, default=0,
                        help='roll angle of initial orientation, radians')
    parser.add_argument('-P', type=float, default=0,
                        help='pitch angle of initial orientation, radians')
    parser.add_argument('-Y', type=float, default=0,
                        help='yaw angle of initial orientation, radians')
    self.args = parser.parse_args(args[1:])

    # low-pass filtered linear acceleration, updated by imu_callback
    self.acc = Vector3()

    # get the share location for the package containing the model
    # we will be training
    try:
        self.package_dir = get_package_share_directory(self.args.package)
    except PackageNotFoundError:
        self.get_logger().error(
            f'cannot find share folder for package {self.args.package}')
        sys.exit(-2)
    xacro_urdf_file = os.path.join(self.package_dir, self.args.xacro)
    self.get_logger().info(f'using xacro urdf file at {xacro_urdf_file}')

    # URDF file in the form of a xacro file.
    # xacro is required since we are training parameters that are
    # replaced during xacro parsing. The handle is deliberately kept open:
    # every episode rewinds and re-parses it with new parameter values.
    try:
        self.xacro_urdf = open(xacro_urdf_file)
    except OSError:
        self.get_logger().error(f'cannot open xacro file: {xacro_urdf_file}')
        sys.exit(-2)

    # subscribe to imu data so we know when the robot has fallen
    sensor_qos = rclpy.qos.QoSPresetProfiles.get_from_short_key('sensor_data')
    self.imu_data = self.create_subscription(
        Imu,
        'imu/data',
        self.imu_callback,
        sensor_qos)
    # odometry gets its own attribute so it no longer overwrites the
    # reference to the imu subscription
    self.odom_data = self.create_subscription(
        Odometry,
        'odom',
        self.odom_callback,
        sensor_qos)
def imu_callback(self, msg):
    """Fold one IMU sample into the node's health state.

    Records the message stamp as ``currentTs`` (and as ``startTs`` for the
    first sample after a reset), updates the exponentially smoothed linear
    acceleration in ``self.acc`` using ``acc_mix`` as the weight of the new
    sample, flags ``fallen`` while the smoothed |z| acceleration lies
    strictly between 9 and 9.85, and smooths the magnitude of the angular
    velocity into ``avgAngularVelocity`` the same way.
    """
    stamp = msg.header.stamp
    self.currentTs = stamp
    if self.startTs is None:
        self.startTs = stamp

    keep = 1 - self.acc_mix
    accel = msg.linear_acceleration
    self.acc.x = self.acc.x * keep + accel.x * self.acc_mix
    self.acc.y = self.acc.y * keep + accel.y * self.acc_mix
    self.acc.z = self.acc.z * keep + accel.z * self.acc_mix

    self.fallen = 9 < abs(self.acc.z) < 9.85

    spin = msg.angular_velocity
    magnitude = math.sqrt(
        spin.x * spin.x +
        spin.y * spin.y +
        spin.z * spin.z
    )
    if self.avgAngularVelocity:
        self.avgAngularVelocity = (
            self.avgAngularVelocity * keep + magnitude * self.acc_mix)
    else:
        # no usable previous average (None or zero): seed with this sample
        self.avgAngularVelocity = magnitude
def odom_callback(self, msg):
    """Update travelled distance and heading from one odometry sample.

    ``distance`` (planar distance from the origin) and ``direction``
    (``atan2(y, x)``) are rounded to two decimals. Both are held at 0 until
    more than 2 seconds (by IMU stamps) have elapsed in the episode.

    The episode clock is driven by ``imu_callback``; if no IMU sample has
    arrived yet (``startTs``/``currentTs`` are None, as after every episode
    reset) the sample is treated as "episode not started" instead of
    raising AttributeError.
    """
    start, now = self.startTs, self.currentTs
    if start is None or now is None or now.sec - start.sec <= 2:
        self.distance = 0
        self.direction = 0
        return

    position = msg.pose.pose.position
    self.distance = round(math.hypot(position.x, position.y), 2)
    self.direction = round(math.atan2(position.y, position.x), 2)
async def spawn_entity(self, entity_xml, initial_pose, timeout=10.0):
    """Spawn the configured entity in Gazebo via the spawn_entity service.

    Originally from gazebo_ros_pkgs/gazebo_ros/scripts/spawn_entity.py
    but modified for asyncio operation with timeouts.

    :param entity_xml: UTF-8 encoded bytes of the entity description.
    :param initial_pose: pose at which the entity is spawned.
    :param timeout: seconds to wait for the service to appear, and again
        for the service call to complete; must be greater than zero.
    :return: False (after logging an error) when ``timeout`` is not
        positive; None on success.
    :raises EntityTimeout: the service is unavailable or the call timed out.
    :raises EntityOperationFailed: the service reported failure.
    """
    if timeout <= 0:
        self.get_logger().error('spawn_entity timeout must be greater than zero')
        return False

    service_name = '%s/spawn_entity' % self.args.gazebo_namespace
    self.get_logger().debug(
        'Waiting for service %s, timeout = %.f' % (service_name, timeout))
    client = self.create_client(SpawnEntity, service_name)
    try:
        # NOTE: wait_for_service is a synchronous call, so other asyncio
        # tasks do not run while it waits.
        if not client.wait_for_service(timeout_sec=timeout):
            self.get_logger().error(
                'Service %s unavailable. Was Gazebo started with GazeboRosFactory?'
                % service_name)
            raise EntityTimeout('spawn_entity service')

        req = SpawnEntity.Request()
        req.name = self.args.entity
        req.xml = str(entity_xml, 'utf-8')
        req.robot_namespace = self.args.robot_namespace
        req.initial_pose = initial_pose
        req.reference_frame = self.args.reference_frame

        self.get_logger().debug('Calling service %s' % service_name)
        try:
            response = await asyncio.wait_for(client.call_async(req), timeout=timeout)
        except asyncio.TimeoutError as err:
            raise EntityTimeout('spawn ' + req.name + ' timeout') from err
        if not response.success:
            raise EntityOperationFailed('spawn ' + req.name)
    finally:
        # a client is created on every call; release it so clients do not
        # accumulate on the node over many episodes
        self.destroy_client(client)
async def delete_entity(self, timeout=10.0, ignore_failure=False):
    """Delete the configured entity from Gazebo via the delete_entity service.

    Originally from gazebo_ros_pkgs/gazebo_ros/scripts/spawn_entity.py
    but modified for asyncio operation with timeouts.

    :param timeout: seconds to wait for the service to appear, and again
        for the service call to complete.
    :param ignore_failure: when True, an unavailable service or a failure
        response is tolerated instead of raised. A timed-out call is still
        raised.
    :raises EntityTimeout: the call timed out, or the service is
        unavailable and ``ignore_failure`` is False.
    :raises EntityOperationFailed: the service reported failure and
        ``ignore_failure`` is False.
    """
    service_name = '%s/delete_entity' % self.args.gazebo_namespace
    self.get_logger().debug('Deleting entity [{}]'.format(self.args.entity))
    client = self.create_client(DeleteEntity, service_name)
    try:
        # NOTE: wait_for_service is a synchronous call, so other asyncio
        # tasks do not run while it waits.
        if not client.wait_for_service(timeout_sec=timeout):
            # The whole message is formatted at once. Previously `%` bound
            # only to the second literal of a `+` concatenation, which
            # raised TypeError instead of logging.
            self.get_logger().error(
                'Service %s unavailable. Was Gazebo started with GazeboRosFactory?'
                % service_name)
            if not ignore_failure:
                raise EntityTimeout('delete_entity service')
            return

        req = DeleteEntity.Request()
        req.name = self.args.entity
        self.get_logger().debug('Calling service %s' % service_name)
        try:
            response = await asyncio.wait_for(client.call_async(req), timeout=timeout)
        except asyncio.TimeoutError as err:
            raise EntityTimeout('delete ' + req.name + ' timeout') from err
        if not response.success and not ignore_failure:
            raise EntityOperationFailed('delete ' + req.name)
    finally:
        # a client is created on every call; release it so clients do not
        # accumulate on the node over many episodes
        self.destroy_client(client)
# unfortunately the gzserver process often does not terminate when a launch is | |
# terminated so we will ensure there is no existing gzserver process before | |
# launching simulation again. | |
@staticmethod
def kill_gzserver():
    """Kill every running process whose name is exactly ``gzserver``.

    Processes that exit while the process table is being walked, or that
    this user is not permitted to inspect or kill, are skipped rather than
    aborting the sweep.
    """
    for proc in psutil.process_iter():
        try:
            # check whether the process name matches
            if proc.name() == 'gzserver':
                proc.kill()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue
# borrowed from launch_service.py in the Ros2 Launch API | |
@staticmethod | |
def parse_launch_arguments(launch_arguments: List[Text]) -> List[Tuple[Text, Text]]: | |
"""Parse the given launch arguments from the command line, into list of tuples for launch.""" | |
parsed_launch_arguments = OrderedDict() # type: ignore | |
for argument in launch_arguments: | |
count = argument.count(':=') | |
if count == 0 or argument.startswith(':=') or (count == 1 and argument.endswith(':=')): | |
raise RuntimeError( | |
"malformed launch argument '{}', expected format '<name>:=<value>'" | |
.format(argument)) | |
name, value = argument.split(':=', maxsplit=1) | |
parsed_launch_arguments[name] = value # last one wins is intentional | |
return parsed_launch_arguments.items() | |
# borrowed from launch_service.py in the Ros2 Launch API | |
def launch_a_launch_file(self, *, launch_file_path, launch_file_arguments, debug=False):
    """Build a LaunchService for a launch file given by path.

    The launch file is pulled in through IncludeLaunchDescription so that
    the location of the current launch file is set, and a process-IO
    handler echoes child stdout/stderr to this process's stdout. The
    service is returned without being started.

    Borrowed from launch_service.py in the Ros2 Launch API.
    """
    def echo_stdout(info):
        print('>>>'+str(info.text)+'<<<')

    def echo_stderr(info):
        print('***'+str(info.text)+'***')

    service = launch.LaunchService(argv=launch_file_arguments, debug=debug)
    name_value_pairs = self.parse_launch_arguments(launch_file_arguments)

    include_action = launch.actions.IncludeLaunchDescription(
        launch.launch_description_sources.AnyLaunchDescriptionSource(
            launch_file_path
        ),
        launch_arguments=name_value_pairs,
    )
    io_echo_action = launch.actions.RegisterEventHandler(
        OnProcessIO(on_stdout=echo_stdout, on_stderr=echo_stderr)
    )

    service.include_launch_description(
        launch.LaunchDescription([include_action, io_echo_action]))
    return service
async def launch_simulation(self, timeout=10.0):
    """Start ``simulation2.launch.py`` from the model package as an asyncio task.

    Any running gzserver process is killed first. The new LaunchService is
    stored in ``self.launch_service`` and the task running it in
    ``self.simulation_task``, which is also returned. ``timeout`` is
    accepted but not used.
    """
    launch_path = get_share_file_path_from_package(
        package_name=self.args.package,
        file_name='simulation2.launch.py')
    service = self.launch_a_launch_file(
        launch_file_path=launch_path,
        launch_file_arguments=["__log_level:=error"],
        debug=False
    )
    self.launch_service = service

    # a gzserver left over from an earlier launch must be gone before we start
    self.kill_gzserver()

    task = asyncio.create_task(service.run_async(shutdown_when_idle=True))
    self.simulation_task = task
    return task
async def kill_simulation(self, timeout=10.0, msg:str = None):
    """Shut down the running launch service, if any, then pause briefly.

    Does nothing when no simulation task is recorded. Otherwise asks the
    launch service to shut down, forgets the simulation task and sleeps
    two seconds. ``timeout`` and ``msg`` are accepted but not used.
    """
    if not self.simulation_task:
        return

    print('cancelling simulation\n')
    # NOTE(review): LaunchService.shutdown() is documented as returning a
    # coroutine only in some situations and None otherwise; awaiting None
    # raises TypeError, so only await when something was returned.
    pending = self.launch_service.shutdown()
    if pending is not None:
        await pending
    self.simulation_task = None
    print('canceled simulation\n')
    await asyncio.sleep(2.0)
def episode(self, config):
    """Run one episode on the asyncio event loop and wait for its result.

    Trampoline for callers on another thread: schedules ``episode_async``
    on ``self.event_loop`` and blocks the calling thread until the
    coroutine finishes, returning whatever it returned.
    """
    pending = asyncio.run_coroutine_threadsafe(
        self.episode_async(config),
        self.event_loop,
    )
    outcome = pending.result()
    return outcome
async def episode_async(self, config):
    """Run one training episode and report its outcome to Hyperopt.

    Resets the per-episode health state, renders the xacro URDF with
    ``config`` as substitution arguments, (re)starts the simulation if it
    is not running, spawns the entity and polls until the robot has
    fallen, has moved more than 1 (odometry units), or more than 30
    seconds of IMU time have elapsed. The result is appended to the
    training log and the entity is deleted again.

    :param config: mapping of parameter name to value chosen by Hyperopt.
    :return: ``{'status': STATUS_FAIL}`` when the simulation could not be
        launched, the entity could not be spawned, or no IMU data is
        available to score the episode. Otherwise a dict with ``loss``,
        ``status`` (STATUS_OK) and the episode measurements.
    """
    # reset per-episode state; the smoothed angular velocity is cleared too
    # so a score never carries over motion from the previous episode
    self.currentTs = self.startTs = None
    self.fallen = False
    self.active = True
    self.distance = 0
    self.direction = 0
    self.acc = Vector3()
    self.avgAngularVelocity = None

    self.write_current_config(config)

    # xacro substitution arguments must all be strings
    mappings = {k: str(v) for (k, v) in config.items()}

    # convert the xacro into the final URDF; rewind first because the same
    # open file is re-parsed every episode
    self.xacro_urdf.seek(0, 0)
    doc = xacro.parse(self.xacro_urdf)
    xacro.process_doc(doc, mappings=mappings)
    entity_xml = doc.toxml('utf-8')

    # Form requested Pose from arguments
    initial_pose = Pose()
    initial_pose.position.x = float(self.args.x)
    initial_pose.position.y = float(self.args.y)
    initial_pose.position.z = float(self.args.z)
    q = quaternion_from_euler(self.args.R, self.args.P, self.args.Y)
    initial_pose.orientation.w = q[0]
    initial_pose.orientation.x = q[1]
    initial_pose.orientation.y = q[2]
    initial_pose.orientation.z = q[3]

    if not self.simulation_task or self.simulation_task.done():
        if self.simulation_task:
            print("simulation is apparently DONE")
        # restart simulation
        print("starting simulation")
        try:
            await self.launch_simulation()
        except Exception as e:
            print("exception launching simulation: ", e)
            return {
                'status': STATUS_FAIL
            }
        # give the freshly launched simulation time to come up
        await asyncio.sleep(5.0)

    try:
        await self.spawn_entity(entity_xml, initial_pose, 30)
    except EntityException as e:
        self.get_logger().error('Spawn service failed: %s' % e)
        await self.kill_simulation()
        return {
            'status': STATUS_FAIL
        }

    # let the simulation run until an end condition is reached; the ROS
    # callbacks that update the health state run between the sleeps
    while True:
        duration = self.currentTs.sec - self.startTs.sec if self.currentTs else 0
        if self.fallen or self.distance > 1 or duration > 30:
            break
        await asyncio.sleep(0.05)

    # snapshot the measurements once so the logged row and the returned
    # dict agree even if callbacks fire while the entity is being deleted
    scored = self.currentTs is not None and self.avgAngularVelocity is not None
    if scored:
        score = self.distance * self.avgAngularVelocity
        if self.fallen:
            score = score * 10
        measurements = {
            'startTs': self.startTs.sec,
            'currentTs': self.currentTs.sec,
            'duration': self.currentTs.sec - self.startTs.sec,
            'distance': self.distance,
            'direction': self.direction,
            'angVelocity': self.avgAngularVelocity,
            'fell': self.fallen,
        }
        self.write_episode_result(config, {**measurements, 'score': score})

    # remove the entity
    try:
        await self.delete_entity()
        # asyncio sleep (not time.sleep) so the event loop keeps running
        await asyncio.sleep(1.0)
    except EntityException as e:
        self.get_logger().error('Delete entity failed: %s' % e)

    if not scored:
        return {
            'status': STATUS_FAIL
        }
    return {
        'loss': score,
        'status': STATUS_OK,
        # extra fields
        **measurements,
    }
@staticmethod | |
def write_current_config(config): | |
columns = ['mu1', 'mu2', 'kp', 'kd'] | |
values = "\n".join([(k + ": " + "{:.2f}".format(v)) for (k, v) in config.items() if k in columns]) | |
cfile = open("current.txt", "w") | |
cfile.write(values) | |
cfile.close() | |
@staticmethod | |
def write_episode_result(config, result): | |
training_file = "training.csv" | |
columns = ['startTs', 'duration', 'distance', 'direction', 'angVelocity', 'fell', 'score', 'mu1', 'mu2', 'kp', 'kd'] | |
combined = {**config, **result} | |
values = [str(combined[k]) for k in columns] | |
write_header = not os.path.exists(training_file) | |
# append config values to file | |
cfile = open(training_file, "a") | |
if write_header: | |
cfile.write(", ".join(columns) + "\n") | |
cfile.write(", ".join(values) + "\n") | |
cfile.close() | |
def optimize(self, config):
    """Run the Hyperopt search over ``config`` and then shut ROS down.

    Calls ``self.episode`` once per evaluation, up to ``-episodes`` times,
    using the TPE suggestion algorithm. Trials are kept in MongoDB when
    ``-mongodb`` was given, otherwise in memory. The best parameters found
    are printed.
    """
    self.get_logger().info(f'Training {self.args.entity} over {self.args.episodes} episodes')

    # choose where trial history is kept
    trials = (
        MongoTrials(self.args.mongodb, exp_key=self.args.expid)
        if self.args.mongodb
        else Trials()
    )

    # search the space; each evaluation is one simulated episode
    best = fmin(
        self.episode,
        space=config,
        algo=tpe.suggest,
        max_evals=self.args.episodes,
        trials=trials
    )
    print(best)

    # shutting down Ros2 will shut down the program
    rclpy.shutdown()
def run(self):
    """Run training: the ROS/asyncio loop here, Hyperopt in a worker thread.

    Builds the Hyperopt search space, creates the asyncio event loop that
    services ROS callbacks and episodes, and starts ``optimize`` in a
    background thread. Returns once ROS has been shut down (which
    ``optimize`` does when training ends) or on Ctrl-C.
    """
    # our parameters to train
    # (narrower ranges used previously: kp 10000-40000, kd 100-1000,
    #  mu1 200-500, mu2 200-1000)
    config = {
        'target': 'gazebo',
        'kp': hp.uniform('kp', 1000, 40000),
        'kd': hp.uniform('kd', 1, 1000),
        'mu1': hp.uniform('mu1', 0.01, 1000),
        'mu2': hp.uniform('mu2', 0.01, 2000)
    }

    # create the loop explicitly; asking for "the current loop" when none
    # is running is deprecated in recent Python versions
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    self.event_loop = loop

    # Optimizer is not asyncio compliant so we will run in a thread and have
    # it send events back to the main event loop. It is a daemon thread so
    # that, after Ctrl-C, a worker still blocked waiting on an episode
    # cannot keep the interpreter from exiting.
    self.optimize_thread = threading.Thread(
        target=self.optimize, args=(config,), daemon=True)

    # this task processes ROS node events until ROS is shut down
    async def rosloop():
        while rclpy.ok():
            rclpy.spin_once(self, timeout_sec=0)
            await asyncio.sleep(0.01)
        # ROS is down, so training is over: let run_forever() return
        loop.stop()

    # perform optimization startup sequence here but from within the event loop
    async def kickstart():
        # begin the optimization thread
        self.optimize_thread.start()

    # perform main event loop processing
    try:
        loop.create_task(rosloop())
        loop.create_task(kickstart())
        loop.run_forever()
    except KeyboardInterrupt:
        pass
    finally:
        # join() raises on a thread that was never started
        if self.optimize_thread.is_alive():
            self.optimize_thread.join(timeout=12.0)
        loop.close()
        self.event_loop = None
# borrowed from gazebo_ros_pkgs/gazebo_ros/scripts/spawn_entity.py | |
def quaternion_from_euler(roll, pitch, yaw):
    """Convert roll, pitch and yaw angles (radians) to a quaternion.

    Returns a four-element list ordered ``[w, x, y, z]``.
    Borrowed from gazebo_ros_pkgs/gazebo_ros/scripts/spawn_entity.py.
    """
    half_yaw = yaw * 0.5
    half_pitch = pitch * 0.5
    half_roll = roll * 0.5

    cos_yaw, sin_yaw = math.cos(half_yaw), math.sin(half_yaw)
    cos_pitch, sin_pitch = math.cos(half_pitch), math.sin(half_pitch)
    cos_roll, sin_roll = math.cos(half_roll), math.sin(half_roll)

    return [
        cos_yaw * cos_pitch * cos_roll + sin_yaw * sin_pitch * sin_roll,
        cos_yaw * cos_pitch * sin_roll - sin_yaw * sin_pitch * cos_roll,
        sin_yaw * cos_pitch * sin_roll + cos_yaw * sin_pitch * cos_roll,
        sin_yaw * cos_pitch * cos_roll - cos_yaw * sin_pitch * sin_roll,
    ]
def main(args=None):
    """Initialise ROS, run the training node and shut ROS down.

    :param args: full argument vector including the program name; defaults
        to ``sys.argv`` as it is at call time.
    """
    if args is None:
        args = sys.argv
    rclpy.init(args=args)
    try:
        args_without_ros = rclpy.utilities.remove_ros_args(args)
        train_model_node = TrainModelNode(args_without_ros)
        train_model_node.run()
    finally:
        # the optimizer shuts ROS down itself when training completes;
        # only shut down here if that has not happened yet
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment