Skip to content

Instantly share code, notes, and snippets.

@sjchoi86
Created January 12, 2025 03:05
Show Gist options
  • Save sjchoi86/f6950c3df4b16ba120972e14aea061bb to your computer and use it in GitHub Desktop.
Save sjchoi86/f6950c3df4b16ba120972e14aea061bb to your computer and use it in GitHub Desktop.
Kinematic chain using NumPy, capable of batched forward kinematics (FK).
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from networkx.algorithms.traversal.depth_first_search import dfs_edges
"""
sys.path.append('../../package/kinematics_helper/') # for 'transforms'
"""
from transforms import (
rodrigues,
rodrigues_batch,
quat2r,
pr2t,
t2p,
t2r,
)
class NumpyChainClass(object):
    """
    Kinematic chain using NumPy.

    Bodies are nodes of a networkx.DiGraph keyed by an integer index equal to
    their insertion order; parent->child links are directed edges.
    'body_names' and 'joint_names' are parallel lists indexed by that same body
    index (a body without a joint gets joint name ''), which is why a joint
    index can be used directly as a node key.
    """
    def __init__(self,name='Numpy Chain'):
        """
        Initialize Chain Class
        Parameters:
            name: str - Name of the chain.
        """
        self.chain = None
        self.init_chain(name=name) # initialize chain

    def init_chain(self,name=None):
        """
        Initialize (or reset) the chain.
        Parameters:
            name: str - Name of the chain.
        """
        # Set name
        self.name = name
        # Clear and instantiate chain
        if self.chain is not None:
            self.chain.clear()
        self.chain = nx.DiGraph(name=name)
        # Clear body and joint names
        self.body_names = []
        self.joint_names = []

    def get_n_body(self):
        """
        Get the number of bodies.
        Returns:
            int - Number of bodies in the chain.
        """
        return self.chain.number_of_nodes()

    def get_body_idx(self,body_name):
        """
        Get the index of a body.
        Parameters:
            body_name: str - Name of the body.
        Returns:
            int - Index of the (first) body with that name.
        Raises:
            ValueError - If no body has that name.
        """
        return self.body_names.index(body_name)

    def get_body_idxs(self,body_names):
        """
        Get the indices of bodies.
        Parameters:
            body_names: list of str - Names of the bodies.
        Returns:
            list of int - Indices of the bodies (None for unknown names).
        """
        # Build the name->index map once (first occurrence wins, as list.index does)
        name2idx = {}
        for idx,name in enumerate(self.body_names):
            name2idx.setdefault(name,idx)
        return [name2idx.get(name) for name in body_names]

    def get_root_body_idx(self):
        """
        Get the root body index.
        Returns:
            int - Position (in node order) of the first body with no parent.
        Raises:
            ValueError - If no body has in-degree 0.
        """
        for position,node in enumerate(self.chain.nodes):
            if self.chain.in_degree(node) == 0: # the root node is the node with in-degree 0
                return position
        raise ValueError("No root node found in the chain.")

    def get_joint_idx(self,joint_name):
        """
        Get the index of a joint.
        Parameters:
            joint_name: str - Name of the joint.
        Returns:
            int - Index of the joint (same as the index of the body it is attached to).
        """
        return self.joint_names.index(joint_name)

    def get_body(self,body_idx):
        """
        Get body information.
        Parameters:
            body_idx: int - Index of the body.
        Returns:
            dict - Information of the body (the live node attribute dict).
        """
        return self.chain.nodes[body_idx]

    def _resolve_body_idx(self,body_name=None,body_idx=None):
        """
        Resolve a body index from a name or an index (the name takes precedence).
        Parameters:
            body_name: str, optional - Name of the body.
            body_idx: int, optional - Index of the body.
        Returns:
            int - Index of the body.
        Raises:
            ValueError - If neither is given, or the name is unknown.
        """
        if body_name is not None:
            return self.get_body_idx(body_name=body_name)
        if body_idx is not None:
            return body_idx
        raise ValueError("Either body name or index should be given.")

    def get_p_body(self,body_name=None,body_idx=None):
        """
        Get the position of a body.
        Parameters:
            body_name: str, optional - Name of the body.
            body_idx: int, optional - Index of the body.
        Returns:
            np.ndarray - Position of the body.
        """
        return self.chain.nodes[self._resolve_body_idx(body_name,body_idx)]['p']

    def get_p_batch_body(self,body_name=None,body_idx=None):
        """
        Get batched positions of a body (set by 'forward_batch').
        Parameters:
            body_name: str, optional - Name of the body.
            body_idx: int, optional - Index of the body.
        Returns:
            np.ndarray - Batched positions of the body, shape [B, 3].
        """
        return self.chain.nodes[self._resolve_body_idx(body_name,body_idx)]['p_batch']

    def get_R_body(self,body_name=None,body_idx=None):
        """
        Get the rotation matrix of a body.
        Parameters:
            body_name: str, optional - Name of the body.
            body_idx: int, optional - Index of the body.
        Returns:
            np.ndarray - Rotation matrix of the body.
        """
        return self.chain.nodes[self._resolve_body_idx(body_name,body_idx)]['R']

    def get_R_batch_body(self,body_name=None,body_idx=None):
        """
        Get batched rotation matrices of a body (set by 'forward_batch').
        Parameters:
            body_name: str, optional - Name of the body.
            body_idx: int, optional - Index of the body.
        Returns:
            np.ndarray - Batched rotation matrices of the body, shape [B, 3, 3].
        """
        return self.chain.nodes[self._resolve_body_idx(body_name,body_idx)]['R_batch']

    def get_T_body(self,body_name=None,body_idx=None):
        """
        Get the transformation matrix of a body.
        Parameters:
            body_name: str, optional - Name of the body.
            body_idx: int, optional - Index of the body.
        Returns:
            np.ndarray - Transformation matrix of the body (built by 'pr2t').
        """
        return pr2t(
            p = self.get_p_body(body_name,body_idx),
            R = self.get_R_body(body_name,body_idx),
        )

    def set_body_info(self,body_idx,key,value):
        """
        Set body information using a key and value.
        Parameters:
            body_idx: int - Index of the body.
            key: str - Key of the information.
            value: Any - Value to set.
        """
        self.chain.nodes[body_idx][key] = value

    def set_p_body(self,body_name=None,body_idx=None,p=None):
        """
        Set the position of a body.
        Parameters:
            body_name: str, optional - Name of the body.
            body_idx: int, optional - Index of the body.
            p: np.ndarray - Position to set.
        """
        self.chain.nodes[self._resolve_body_idx(body_name,body_idx)]['p'] = p

    def set_R_body(self,body_name=None,body_idx=None,R=None):
        """
        Set the rotation matrix of a body.
        Parameters:
            body_name: str, optional - Name of the body.
            body_idx: int, optional - Index of the body.
            R: np.ndarray - Rotation matrix to set.
        """
        self.chain.nodes[self._resolve_body_idx(body_name,body_idx)]['R'] = R

    def set_T_body(self,body_name=None,body_idx=None,T=None):
        """
        Set the transformation matrix of a body (stored as 'p' and 'R').
        Parameters:
            body_name: str, optional - Name of the body.
            body_idx: int, optional - Index of the body.
            T: np.ndarray - Transformation matrix to set.
        """
        body = self.chain.nodes[self._resolve_body_idx(body_name,body_idx)]
        body['p'] = t2p(T)
        body['R'] = t2r(T)

    def set_q_joint(self,joint_name,q):
        """
        Set the joint angle for a joint.
        Parameters:
            joint_name: str - Name of the joint.
            q: float - Joint angle to set.
        """
        # Joint index == body index because 'joint_names' parallels 'body_names'
        self.chain.nodes[self.get_joint_idx(joint_name)]['q'] = q

    def set_qs_joints(self,joint_names,qs,forward=False,root_body_name=None):
        """
        Set joint angles for multiple joints.
        Parameters:
            joint_names: list of str - Names of the joints.
            qs: list or np.ndarray - Joint angles to set.
            forward: bool, optional - Whether to perform forward kinematics after setting.
            root_body_name: str, optional - Name of the root body.
        """
        for (joint_name,q) in zip(joint_names,qs):
            self.set_q_joint(joint_name=joint_name,q=q)
        if forward:
            self.forward(root_body_name=root_body_name)

    def set_qs_joints_batch(self, joint_names, q_batch):
        """
        Set batched joint angles for multiple joints.
        Parameters:
            joint_names: list of str - Names of the joints.
            q_batch: np.ndarray, shape [B, N] - Batch of joint angles; column i
                belongs to joint_names[i].
        Raises:
            ValueError - If q_batch is not 2-dimensional.
            IndexError - If q_batch has fewer columns than joint_names.
        """
        q_batch = np.asarray(q_batch)
        if q_batch.ndim != 2:
            raise ValueError("q_batch should have shape [B, N], got %s."%(q_batch.shape,))
        if q_batch.shape[1] < len(joint_names):
            raise IndexError("q_batch has %d columns but %d joint names were given."%
                             (q_batch.shape[1],len(joint_names)))
        for i, joint_name in enumerate(joint_names):
            self.chain.nodes[self.get_joint_idx(joint_name)]['q_batch'] = q_batch[:, i] # shape [B]

    def add_body(
        self,
        body_name = '',
        parent_body_name = '',
        joint_name = '',
        p_offset = None,
        R_offset = None,
        a = None,
    ):
        """
        Add a body to the chain.
        Parameters:
            body_name: str - Name of the body.
            parent_body_name: str - Name of the parent body (a body whose parent
                name equals its own name, or None, is added without a parent).
            joint_name: str - Name of the joint associated with the body.
            p_offset: np.ndarray, optional - Positional offset of the body (default: zeros(3)).
            R_offset: np.ndarray, optional - Rotational offset of the body (default: eye(3)).
            a: np.ndarray, optional - Joint axis of the body (default: zeros(3), i.e. no revolute joint).
        Raises:
            ValueError - If the parent body name is not in the chain (nothing is added then).
        """
        # Fresh arrays per call: ndarray defaults in the signature would be a
        # single object shared by every body added without explicit values.
        if p_offset is None:
            p_offset = np.zeros(3)
        if R_offset is None:
            R_offset = np.eye(3)
        if a is None:
            a = np.zeros(3)
        # Resolve the parent first so an unknown parent leaves the chain untouched
        has_parent = (parent_body_name is not None) and (body_name != parent_body_name)
        parent_body_idx = self.get_body_idx(parent_body_name) if has_parent else None
        # Add new body (='node' in 'chain')
        new_body_idx = self.get_n_body()
        self.chain.add_node(new_body_idx) # add new node (=body)
        # Update body information
        body_info = {
            'name':body_name,
            'parent_name':parent_body_name,
            'joint_name':joint_name,
            'p_offset':p_offset,
            'R_offset':R_offset,
            'a':a,
            'q':0.0,
            'p':np.zeros(3),
            'R':np.eye(3),
            'parent':[], # replaced by the parent index below; stays [] for a root body
            'childs':[],
        }
        self.chain.update(nodes=[(new_body_idx,body_info)])
        # Append body and joint names (kept parallel: same index for both)
        self.body_names.append(body_name)
        self.joint_names.append(joint_name)
        # Link to the parent body (if any)
        if has_parent:
            self.chain.nodes[new_body_idx]['parent'] = parent_body_idx
            # Connect parent and child
            self.chain.add_edge(u_of_edge=parent_body_idx,v_of_edge=new_body_idx)
            # Register the new body among its PARENT's children
            self.chain.nodes[parent_body_idx]['childs'].append(new_body_idx)

    def build_chain_from_mujoco_env(
        self,
        env = None,
        init_first = True,
    ):
        """
        Build a kinematic chain from a MuJoCo environment.
        Parameters:
            env: MuJoCo environment - The environment to parse (reads env.name,
                env.n_body, env.body_names and env.model).
            init_first: bool, optional - Whether to reset the chain before building.
        Raises:
            ValueError - If env is None.
        """
        if env is None:
            raise ValueError("A MuJoCo environment should be given.")
        # Reset first
        if init_first:
            self.init_chain(name=env.name)
        for body_idx in range(env.n_body):
            # Parse body information
            body_name = env.body_names[body_idx]
            body = env.model.body(body_name)
            parent_body_name = env.body_names[body.parentid[0]]
            # Parse joint information: only bodies with exactly one attached joint get an axis
            if body.jntnum == 1:
                joint = env.model.joint(body.jntadr[0]) # the attached joint
                joint_name = joint.name
                # The joint position offset (joint.pos) is currently not supported.
                # NOTE(review): the joint type is not checked, so a single non-hinge
                # joint with a unit axis is treated as revolute by 'forward' — confirm.
                joint_axis = joint.axis
            else:
                joint_name = ''
                joint_axis = np.zeros(3)
            # Add body to the chain
            self.add_body(
                body_name = body_name,
                parent_body_name = parent_body_name,
                joint_name = joint_name,
                p_offset = body.pos,
                R_offset = quat2r(body.quat),
                a = joint_axis,
            )

    def forward(self,joint_names=None,q=None,root_body_name=None):
        """
        Perform forward kinematics, updating 'p' and 'R' of every body below the root.
        Parameters:
            joint_names: list of str, optional - Names of the joints.
            q: np.ndarray, optional - Joint angles.
            root_body_name: str, optional - Name of the root body (its own pose is not changed).
        """
        # Set joint angles (optional)
        if joint_names is not None:
            self.set_qs_joints(joint_names=joint_names,qs=q)
        # Get root body index
        if root_body_name is None:
            root_body_idx = self.get_root_body_idx()
        else:
            root_body_idx = self.get_body_idx(root_body_name)
        # DFS edges reach a parent before its children, so body_fr already holds
        # its updated pose when body_to is computed.
        for idx_fr,idx_to in dfs_edges(self.chain,source=root_body_idx):
            body_fr = self.get_body(idx_fr)
            body_to = self.get_body(idx_to)
            # Update p
            p = body_fr['R']@body_to['p_offset'] + body_fr['p'] # [3]
            self.set_body_info(body_idx=idx_to,key='p',value=p)
            # Update R
            R = body_fr['R']@body_to['R_offset'] # [3x3]
            a_to = body_to['a']
            if abs(np.linalg.norm(a_to)-1) < 1e-6: # unit axis => revolute joint
                R = R@rodrigues(a=a_to,q_rad=body_to['q']) # [3x3]
            self.set_body_info(body_idx=idx_to,key='R',value=R)

    def _get_batch_size(self,q_batch=None):
        """
        Get the batch size B for 'forward_batch'.
        Parameters:
            q_batch: array-like, optional - Batch of joint angles, shape [B, N].
        Returns:
            int - B from q_batch, or else from the first stored 'q_batch' of a body.
        Raises:
            ValueError - If B cannot be determined.
        """
        if q_batch is not None:
            return np.shape(q_batch)[0]
        for body_idx in self.chain.nodes:
            stored_q_batch = self.chain.nodes[body_idx].get('q_batch')
            if stored_q_batch is not None:
                return len(stored_q_batch)
        raise ValueError("q_batch should be given (no stored batched joint angles found).")

    def forward_batch(self, joint_names=None, q_batch=None, root_body_name=None):
        """
        Perform batched forward kinematics, filling 'p_batch' [B, 3] and
        'R_batch' [B, 3, 3] of the root and every body below it.
        Parameters:
            joint_names: list of str, optional - Names of the joints.
            q_batch: np.ndarray, shape [B, N], optional - Batch of joint angles for
                the joints. When omitted, previously stored batches are used.
            root_body_name: str, optional - Name of the root body.
        Note:
            A revolute body without a stored 'q_batch' uses its scalar 'q' for
            every batch entry.
        """
        if joint_names is not None:
            self.set_qs_joints_batch(joint_names=joint_names, q_batch=q_batch)
        root_body_idx = self.get_root_body_idx() if root_body_name is None else self.get_body_idx(root_body_name)
        B = self._get_batch_size(q_batch) # Number of batches
        # Broadcast the root body's p and R to the batch size
        root_body = self.get_body(root_body_idx)
        root_body['p_batch'] = np.tile(root_body['p'], (B, 1)) # shape [B, 3]
        root_body['R_batch'] = np.tile(root_body['R'], (B, 1, 1)) # shape [B, 3, 3]
        for idx_fr, idx_to in dfs_edges(self.chain, source=root_body_idx):
            body_fr = self.get_body(idx_fr)
            body_to = self.get_body(idx_to)
            R_fr_batch = body_fr['R_batch'] # shape [B, 3, 3]
            # Update p_batch: R_fr @ p_offset + p_fr for every batch entry
            p_to_batch = np.einsum('bij,j->bi', R_fr_batch, body_to['p_offset']) + body_fr['p_batch'] # shape [B, 3]
            body_to['p_batch'] = p_to_batch
            # Update R_batch
            R_to_batch = np.matmul(R_fr_batch, body_to['R_offset']) # shape [B, 3, 3]
            a_to = body_to['a'] # shape [3]
            if abs(np.linalg.norm(a_to) - 1) < 1e-6: # Revolute joint
                q_to_batch = body_to.get('q_batch') # shape [B]
                if q_to_batch is None:
                    q_to_batch = np.full(B, body_to['q'], dtype=float)
                R_to_batch = np.matmul(
                    R_to_batch,
                    rodrigues_batch(a=np.tile(a_to, (B, 1)), q_rad=q_to_batch)
                ) # shape [B, 3, 3]
            body_to['R_batch'] = R_to_batch

    def plot_graph(
        self,
        align = 'horizontal',
        figsize = (6,4),
        node_size = 100,
        nose_font_size = 6,
        node_colors = None,
        title_font_size = 10,
        root_on_top = True,
    ):
        """
        Plot the kinematic chain graph.
        Parameters:
            align: str, optional - Layout alignment ('horizontal' or 'vertical').
            figsize: tuple, optional - Figure size.
            node_size: int, optional - Size of the nodes.
            nose_font_size: int, optional - Font size for node labels.
            node_colors: list, optional - Colors of the nodes. By default, bodies
                without a joint axis are white and others are colored by their
                dominant axis component (x: red, y: green, z: blue).
            title_font_size: int, optional - Font size of the title.
            root_on_top: bool, optional - Whether to place the root node on top.
        Note:
            Writes a 'layer' attribute (topological generation) into every node.
        """
        n_body = self.get_n_body()
        tree = self.chain
        for layer, nodes in enumerate(nx.topological_generations(tree)):
            for node in nodes:
                tree.nodes[node]['layer'] = layer
        pos = nx.multipartite_layout(
            tree,
            align = align,
            scale = 1.0,
            subset_key ='layer',
        )
        # Invert the tree so that the root node comes on the top
        if root_on_top:
            pos = {node: (x, -y) for node, (x, y) in pos.items()}
        # Plot nodes
        fig,ax = plt.subplots(figsize=figsize)
        if node_colors is None: # set colors
            node_colors = []
            for body_idx in range(n_body):
                a = self.get_body(body_idx)['a']
                if np.linalg.norm(a) < 1e-6:
                    node_color = (1,1,1,0.5)
                else:
                    node_color = [0,0,0,0.5]
                    # Use the magnitude so negative axes (e.g. [0,-1,0]) get the right color
                    node_color[np.argmax(np.abs(a))] = 1
                    node_color = tuple(node_color)
                node_colors.append(node_color)
        nx.draw_networkx(
            tree,
            pos = pos,
            ax = ax,
            with_labels = True,
            node_size = node_size,
            font_size = nose_font_size,
            node_color = node_colors,
            linewidths = 1,
            edgecolors = 'k'
        )
        ax.set_title('%s'%(tree.name),fontsize=title_font_size)
        fig.tight_layout()
        plt.show()

    def plot_chain_mujoco(
        self,
        env, # <= assumed to be instantiated from 'MuJoCoParserClass'
        plot_link = True,
        r_link = 0.01,
        rgba_link = (0,0,1,0.25),
        root_body_name = None,
        plot_body = True,
        plot_body_name = False,
        plot_body_axis = True,
        axis_len = 0.025,
        axis_width = 0.0025,
        rate = 1.0,
        plot_body_sphere = False,
        r_body = 0.025,
        rgba_body = (0,0,0,0.5),
        plot_rev_axis = True,
        r_axis = 0.01,
        h_axis = 0.05,
        alpha_axis = 0.5,
    ):
        """
        Plot the kinematic chain in a MuJoCo environment.
        Parameters:
            env: MuJoCo environment - Environment to plot.
            plot_link: bool, optional - Whether to plot links.
            r_link: float, optional - Radius of the links.
            rgba_link: tuple, optional - Color of the links (RGBA).
            root_body_name: str, optional - Name of the root body.
            plot_body: bool, optional - Whether to plot bodies.
            plot_body_name: bool, optional - Whether to display body names.
            plot_body_axis: bool, optional - Whether to plot body axes.
            axis_len: float, optional - Length of the axes.
            axis_width: float, optional - Width of the axes.
            rate: float, optional - Scaling factor.
            plot_body_sphere: bool, optional - Whether to plot bodies as spheres.
            r_body: float, optional - Radius of the body spheres.
            rgba_body: tuple, optional - Color of the bodies (RGBA).
            plot_rev_axis: bool, optional - Whether to plot revolute axes.
            r_axis: float, optional - Radius of the axes.
            h_axis: float, optional - Height of the axes.
            alpha_axis: float, optional - Transparency of the axes.
        """
        if root_body_name is None:
            root_body_idx = self.get_root_body_idx()
        else:
            root_body_idx = self.get_body_idx(body_name=root_body_name)
        # Plot link
        if plot_link:
            for idx_fr,idx_to in dfs_edges(self.chain,source=root_body_idx):
                env.plot_cylinder_fr2to(
                    p_fr = self.get_body(idx_fr)['p'],
                    p_to = self.get_body(idx_to)['p'],
                    r = r_link,
                    rgba = rgba_link,
                )
        # Bodies reachable from the root (computed once for both plots below)
        if plot_body or plot_rev_axis:
            subtree_body_idxs = list(nx.dfs_tree(self.chain,source=root_body_idx).nodes)
        # Plot body
        if plot_body:
            for body_idx in subtree_body_idxs:
                body = self.get_body(body_idx)
                env.plot_T(
                    p = body['p'],
                    R = body['R'],
                    plot_axis = plot_body_axis,
                    axis_len = rate*axis_len,
                    axis_width = rate*axis_width,
                    plot_sphere = plot_body_sphere,
                    sphere_r = r_body,
                    sphere_rgba = rgba_body,
                    label = body['name'] if plot_body_name else '',
                )
        # Plot revolute axis
        if plot_rev_axis:
            for body_idx in subtree_body_idxs:
                body = self.get_body(body_idx)
                a = body['a']
                if np.linalg.norm(a) > 1e-6:
                    p,R = body['p'],body['R']
                    p2 = p + R@a*rate*h_axis
                    axis_color = [0,0,0,alpha_axis]
                    # Use the magnitude so negative axes get the right color
                    axis_color[np.argmax(np.abs(a))] = 1
                    env.plot_arrow_fr2to(p_fr=p,p_to=p2,r=rate*r_axis,rgba=axis_color)
@sjchoi86
Copy link
Author

# --- Initialize numpy chain ---
# Build the chain from the MuJoCo environment `env` (assumed to exist already)
nc = NumpyChainClass(name=env.name)
nc.build_chain_from_mujoco_env(env=env)

# --- Forward kinematics ---
# Propagate poses downward from the 'pelvis' body
nc.forward(root_body_name='pelvis')

# --- Plot graph ---
nc.plot_graph(
    figsize=(6,4),
    node_size=100,
    nose_font_size=6,
)
print("Done.")
image

@sjchoi86
Copy link
Author

# Set pelvis position
# Copy the pelvis position from MuJoCo so the chain rooted at 'pelvis' starts
# from the same world position as the simulator.
env.forward()
p_pelvis = env.get_p_body(body_name='pelvis')
nc.set_p_body(body_name='pelvis',p=p_pelvis)

# Batched FK using NumPy Chain
B = 10000 # batchsize
q0_batch = np.random.randn(B,env.n_rev_joint) # [B x n_rev_joint]
env.tic()
nc.forward_batch(joint_names=env.rev_joint_names,
                 q_batch=q0_batch,root_body_name='pelvis')
esec_batch = env.toc()

# Batched FK using MuJoCo (GT), one configuration at a time
env.tic()
for b_idx in range(B):
    q0 = q0_batch[b_idx,:]
    env.forward(q=q0,joint_names=env.rev_joint_names)
esec_seq = env.toc()
# The format labels these as ms, so the elapsed times are scaled by 1000
# (presumably env.toc() returns seconds — confirm).
print ("Done. esec_batch:[%.2fms] and esec_seq[%.2fms] for [%d] batched FK"%
       (esec_batch*1000,esec_seq*1000,B))

# Check batched FK
body_name_to_check = 'right_six_link'
p_err_th = 1e-8 # samples whose error exceeds this are printed
p_batch_nc_check = nc.get_p_batch_body(body_name=body_name_to_check) # [B x 3]
p_errs = np.zeros(B)
for b_idx in range(B):
    q0 = q0_batch[b_idx,:]
    env.forward(q=q0,joint_names=env.rev_joint_names)
    p_env_check = env.get_p_body(body_name=body_name_to_check)
    p_nc_check = p_batch_nc_check[b_idx,:]
    p_err = np.linalg.norm(p_env_check-p_nc_check)
    # Append error
    p_errs[b_idx] = p_err
    # Print only the samples above the threshold; these are errors, hence red.
    # NOTE(review): `colored` is not imported in this snippet (termcolor?) — make sure it is in scope.
    if p_err > p_err_th:
        print ("[%d/%d] p_gt:[%.2f,%.2f,%.2f] p_nc:[%.2f,%.2f,%.2f] p_err:[%s]"%
               (b_idx,B,p_env_check[0],p_env_check[1],p_env_check[2],
                p_nc_check[0],p_nc_check[1],p_nc_check[2],
                colored('%.2e'%(p_err),'red')))

# Print results
print ("[%d] batched FK and the maximum error is:[%.3e]"%
       (B,p_errs.max()))
print ("Done.")

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment