Skip to content

Instantly share code, notes, and snippets.

@xaedes
Last active October 21, 2019 15:03
Show Gist options
  • Save xaedes/17f5f1fb2a3d73347872f2707d4dbd30 to your computer and use it in GitHub Desktop.
Save xaedes/17f5f1fb2a3d73347872f2707d4dbd30 to your computer and use it in GitHub Desktop.
Generate Python code with Numba JIT from a SymPy array expression
import keyword
import math
import re

import numba
import numpy as np
import sympy
def sympy_rot_x(angle):
    """Return the 4x4 homogeneous rotation matrix about the x axis.

    The entries are SymPy ``cos``/``sin`` of *angle* mixed with plain ints,
    so the NumPy result holds Python objects rather than floats.
    """
    cos_a = sympy.cos(angle)
    sin_a = sympy.sin(angle)
    return np.array(
        [
            [1, 0, 0, 0],
            [0, cos_a, -sin_a, 0],
            [0, sin_a, cos_a, 0],
            [0, 0, 0, 1],
        ]
    )
def sympy_rot_y(angle):
    """Return the 4x4 homogeneous rotation matrix about the y axis.

    The entries are SymPy ``cos``/``sin`` of *angle* mixed with plain ints,
    so the NumPy result holds Python objects rather than floats.
    """
    cos_a = sympy.cos(angle)
    sin_a = sympy.sin(angle)
    return np.array(
        [
            [cos_a, 0, sin_a, 0],
            [0, 1, 0, 0],
            [-sin_a, 0, cos_a, 0],
            [0, 0, 0, 1],
        ]
    )
def sympy_rot_z(angle):
    """Return the 4x4 homogeneous rotation matrix about the z axis.

    The entries are SymPy ``cos``/``sin`` of *angle* mixed with plain ints,
    so the NumPy result holds Python objects rather than floats.
    """
    cos_a = sympy.cos(angle)
    sin_a = sympy.sin(angle)
    return np.array(
        [
            [cos_a, -sin_a, 0, 0],
            [sin_a, cos_a, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ]
    )
def sympy_rollpitchyaw(roll, pitch, yaw):
    """Return the symbolic 4x4 product ``Rx(roll) . Ry(pitch) . Rz(yaw)``.

    NOTE(review): this is the x-y-z product order; the common aerospace
    roll/pitch/yaw convention is ``Rz(yaw) . Ry(pitch) . Rx(roll)`` —
    confirm which one callers expect.
    """
    about_x = sympy_rot_x(roll)
    about_y = sympy_rot_y(pitch)
    about_z = sympy_rot_z(yaw)
    return about_x.dot(about_y).dot(about_z)
def name_for_expression(expr):
    """Derive a readable Python identifier from an expression's string form.

    Every non-word character of ``str(expr)`` becomes ``_``, leading and
    trailing underscores are stripped, and runs of underscores collapse to
    one, e.g. ``"cos(roll)*sin(yaw)"`` -> ``"cos_roll_sin_yaw"``.

    The result is meant to be used as a variable name in generated code, so
    it is forced to be assignable: a result that is empty or starts with a
    digit (``"2*x"`` -> ``"2_x"``) is prefixed with ``expr_``, and a Python
    keyword gets a trailing ``_``.
    """
    # Raw strings: "\w" in a plain literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).  \W already excludes digits and "_".
    name = re.sub(r"\W", "_", str(expr)).strip("_")
    name = re.sub(r"_{2,}", "_", name)
    if not name.isidentifier():
        name = "expr_" + name
    if keyword.iskeyword(name):
        name += "_"
    return name
def sp_to_numba(name, inputs, sp_array, initial_indent="", output_name="outmatrix"):
    """Print Numba-decorated Python source that fills an array from SymPy expressions.

    Common subexpressions of *sp_array* are extracted with ``sympy.cse`` and
    assigned to locals whose names are derived from the expressions
    themselves (via ``name_for_expression``); each element is then written
    into the output parameter by index.  The source is printed to stdout.

    Args:
        name: Name of the generated function.
        inputs: Symbols that become the generated function's leading parameters.
        sp_array: Array (or nested list) of SymPy expressions / numbers; any
            number of dimensions (the original use case is 2-D).
        initial_indent: Prefix prepended to every generated line.
        output_name: Name of the trailing parameter that receives the results;
            the generated code only assigns into it, so callers must pass a
            preallocated array with *sp_array*'s shape.

    Returns:
        ``(replacements, reduced_arr)``: the ``(symbol, expression)`` pairs
        emitted as locals, in emission order, and an object array shaped like
        *sp_array* holding each element rewritten in terms of those symbols.

    Note:
        ``sympy.pycode`` renders functions such as ``cos`` as ``math.cos``, so
        the module the generated code lives in must ``import math`` (and
        ``import numba`` for the decorator).
    """
    sp_array = np.asarray(sp_array)
    flat_exprs = list(sp_array.flatten())

    # One CSE pass with sympy's throwaway placeholder symbols (x0, x1, ...).
    # They are renamed below instead of re-running cse with custom symbols:
    # cse silently skips supplied symbols that already occur in the input,
    # which would shift every later name onto the wrong subexpression.
    placeholder_pairs, placeholder_reduced = sympy.cse(flat_exprs)

    # Identifiers a generated local must not reuse: the parameters, the
    # function name, every symbol in the input, and the `math` module that
    # pycode output refers to.
    taken = {str(symbol) for symbol in inputs}
    for expr in flat_exprs:
        taken.update(str(symbol) for symbol in sympy.sympify(expr).free_symbols)
    taken.update((name, output_name, "math"))

    renamed = {}
    replacements = []
    for placeholder, expr in placeholder_pairs:
        # Rewrite in terms of the readable names chosen so far, so the new
        # name describes the expression (cos_roll_sin_yaw, not x2_x4).
        expr = expr.xreplace(renamed)
        base = name_for_expression(expr)
        if not base.isidentifier():
            base = "expr_" + base
        # Distinct subexpressions can stringify to the same name; equal-named
        # Symbols are identical in SymPy, so suffix to keep them apart.
        candidate, counter = base, 1
        while candidate in taken:
            candidate = "%s_%d" % (base, counter)
            counter += 1
        taken.add(candidate)
        symbol = sympy.Symbol(candidate)
        renamed[placeholder] = symbol
        replacements.append((symbol, expr))

    reduced_arr = np.array(
        [expr.xreplace(renamed) for expr in placeholder_reduced], dtype=object
    ).reshape(sp_array.shape)

    indent = initial_indent + " " * 4
    params = ", ".join(str(arg) for arg in list(inputs) + [output_name])
    lines = [
        initial_indent + "@numba.jit",
        initial_indent + "def %s(%s):" % (name, params),
    ]
    lines.extend(
        indent + "%s = %s" % (symbol, sympy.pycode(expr))
        for symbol, expr in replacements
    )
    # np.ndindex walks indices in C order, matching the row-by-row layout;
    # for 2-D arrays this prints e.g. "outmatrix[0,1] = ...".
    lines.extend(
        indent
        + "%s[%s] = %s"
        % (output_name, ",".join(map(str, index)), sympy.pycode(reduced_arr[index]))
        for index in np.ndindex(reduced_arr.shape)
    )
    print("\n".join(lines))
    return replacements, reduced_arr
# Demo: print Numba source for the symbolic roll/pitch/yaw rotation matrix.
sp_sym = sympy.symbols('roll pitch yaw')
sp_fun = sympy_rollpitchyaw(*sp_sym)
sp_to_numba(
    name="numba_" + sympy_rollpitchyaw.__name__,
    inputs=sp_sym,
    sp_array=sp_fun,
)
# Output: sample source produced by the sp_to_numba generator for the
# roll/pitch/yaw demo, kept here as a ready-to-use jitted function.
# NOTE(review): the body calls math.cos / math.sin, so `math` must be
# imported at module level for this to compile — confirm the import exists.
@numba.jit
def numba_sympy_rollpitchyaw(roll, pitch, yaw, outmatrix):
    """Write ``Rx(roll) . Ry(pitch) . Rz(yaw)`` as a 4x4 homogeneous matrix into *outmatrix*.

    *outmatrix* is filled in place at indices [0..3, 0..3]; nothing is
    returned.  Presumably a preallocated (4, 4) float array — confirm with
    callers.
    """
    # Sines/cosines of the three angles, each computed once.
    cos_pitch = math.cos(pitch)
    cos_yaw = math.cos(yaw)
    sin_yaw = math.sin(yaw)
    sin_pitch = math.sin(pitch)
    cos_roll = math.cos(roll)
    # Products shared by several entries; names like x2_x4 echo the CSE
    # placeholder symbols they were generated from, not their meaning.
    x2_x4 = cos_roll*sin_yaw
    sin_roll = math.sin(roll)
    x1_x6 = cos_yaw*sin_roll
    x1_x4 = cos_roll*cos_yaw
    x2_x6 = sin_roll*sin_yaw
    # 3x3 rotation block.
    outmatrix[0,0] = cos_pitch*cos_yaw
    outmatrix[0,1] = -cos_pitch*sin_yaw
    outmatrix[0,2] = sin_pitch
    outmatrix[0,3] = 0
    outmatrix[1,0] = sin_pitch*x1_x6 + x2_x4
    outmatrix[1,1] = -sin_pitch*x2_x6 + x1_x4
    outmatrix[1,2] = -cos_pitch*sin_roll
    outmatrix[1,3] = 0
    outmatrix[2,0] = -sin_pitch*x1_x4 + x2_x6
    outmatrix[2,1] = sin_pitch*x2_x4 + x1_x6
    outmatrix[2,2] = cos_pitch*cos_roll
    outmatrix[2,3] = 0
    # Homogeneous row: no translation, scale 1.
    outmatrix[3,0] = 0
    outmatrix[3,1] = 0
    outmatrix[3,2] = 0
    outmatrix[3,3] = 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment