Last active
October 21, 2019 15:03
-
-
Save xaedes/17f5f1fb2a3d73347872f2707d4dbd30 to your computer and use it in GitHub Desktop.
Generate Python code, compiled with Numba's JIT, from a SymPy array expression
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math
import re

import numba
import numpy as np
import sympy
def sympy_rot_x(angle):
    """Return the 4x4 homogeneous rotation about the x axis.

    The result is a NumPy array mixing plain ints (0/1) with SymPy
    ``cos(angle)`` / ``sin(angle)`` expressions, so it can be multiplied
    symbolically with ``dot``.
    """
    c = sympy.cos(angle)
    s = sympy.sin(angle)
    rows = [
        [1, 0, 0, 0],
        [0, c, -s, 0],
        [0, s, c, 0],
        [0, 0, 0, 1],
    ]
    return np.array(rows)
def sympy_rot_y(angle):
    """Return the 4x4 homogeneous rotation about the y axis.

    The result is a NumPy array mixing plain ints (0/1) with SymPy
    ``cos(angle)`` / ``sin(angle)`` expressions, so it can be multiplied
    symbolically with ``dot``.
    """
    c = sympy.cos(angle)
    s = sympy.sin(angle)
    rows = [
        [c, 0, s, 0],
        [0, 1, 0, 0],
        [-s, 0, c, 0],
        [0, 0, 0, 1],
    ]
    return np.array(rows)
def sympy_rot_z(angle):
    """Return the 4x4 homogeneous rotation about the z axis.

    The result is a NumPy array mixing plain ints (0/1) with SymPy
    ``cos(angle)`` / ``sin(angle)`` expressions, so it can be multiplied
    symbolically with ``dot``.
    """
    c = sympy.cos(angle)
    s = sympy.sin(angle)
    rows = [
        [c, -s, 0, 0],
        [s, c, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ]
    return np.array(rows)
def sympy_rollpitchyaw(roll, pitch, yaw):
    """Return ``Rx(roll) . Ry(pitch) . Rz(yaw)`` as a symbolic 4x4 array.

    The product is evaluated left to right: x and y rotations first, then
    the z rotation.
    """
    rx_ry = np.dot(sympy_rot_x(roll), sympy_rot_y(pitch))
    return np.dot(rx_ry, sympy_rot_z(yaw))
def name_for_expression(expr):
    """Turn ``str(expr)`` into an underscore-separated name.

    Every character that is not a word character (letter, digit or
    underscore) becomes ``_``. Leading and trailing underscores are stripped,
    and runs of underscores collapse to one. For example,
    ``cos(roll)*sin(yaw)`` becomes ``cos_roll_sin_yaw``.

    The result is not guaranteed to be a valid Python identifier: it may be
    empty, or start with a digit (``2*x`` becomes ``2_x``).
    """
    # Raw strings: "\w"/"\d" in plain literals are invalid escape sequences
    # (DeprecationWarning, SyntaxWarning since 3.12). "[^\w\d_]" == "\W"
    # because \w already covers digits and "_".
    name = re.sub(r"\W", "_", str(expr))
    name = re.sub(r"^_+|_+$", "", name)
    return re.sub(r"__+", "_", name)
def _unique_identifier(base, used):
    """Return a valid Python identifier derived from *base* that is not in *used*.

    The chosen name is added to *used*.
    - An empty *base* becomes ``expr``.
    - A keyword, or a name that is not an identifier (e.g. ``2_x``), gets an
      ``e_`` prefix.
    - A clash with an existing name gets a ``_2``, ``_3``, ... suffix.
    """
    import keyword

    candidate = base or "expr"
    if keyword.iskeyword(candidate) or not candidate.isidentifier():
        candidate = "e_" + candidate
        if not candidate.isidentifier():
            # \w also matches some Unicode characters (e.g. "²") that are
            # not allowed in identifiers; fall back to a generic name.
            candidate = "expr"
    unique = candidate
    counter = 1
    while unique in used:
        counter += 1
        unique = "%s_%d" % (candidate, counter)
    used.add(unique)
    return unique


def _readable_cse(exprs, used):
    """Run ``sympy.cse`` and rename each temporary after what it computes.

    Returns ``(replacements, reduced)`` like ``sympy.cse``. Each temporary is
    renamed via ``name_for_expression`` and made unique against *used*, which
    is updated in place.
    """
    replacements, reduced = sympy.cse(exprs)
    # Keep readable names from aliasing cse's own placeholder symbols.
    used.update(str(tmp) for tmp, _ in replacements)
    renames = {}
    readable = []
    for tmp, expr in replacements:
        # Replacements are in evaluation order, so every temporary this one
        # uses is already renamed. That yields e.g. "cos_roll_sin_yaw" rather
        # than a name built from placeholders like "x2_x4".
        expr = expr.xreplace(renames)
        new_sym = sympy.Symbol(_unique_identifier(name_for_expression(expr), used))
        renames[tmp] = new_sym
        readable.append((new_sym, expr))
    reduced = [
        e.xreplace(renames) if isinstance(e, sympy.Basic) else e
        for e in reduced
    ]
    return readable, reduced


def sp_to_numba(name, inputs, sp_array, initial_indent="", output_name = "outmatrix"):
    """Print a ``@numba.jit`` function that writes ``sp_array`` into an output array.

    The generated function takes the symbols in ``inputs`` followed by an
    array parameter ``output_name``. It assigns every element of
    ``sp_array`` into ``output_name[i, j, ...]``. Common subexpressions are
    hoisted into local variables named after their content. Those names are
    unique valid identifiers and never shadow the inputs, the output array,
    the function name, ``math`` or ``numba``.

    The code is printed to stdout. ``sympy.pycode`` renders functions such as
    cos/sin as ``math.cos``/``math.sin``, so the module the code is pasted
    into needs ``import math`` and ``import numba``.

    :param name: name of the generated function.
    :param inputs: iterable of SymPy symbols, used as the parameters.
    :param sp_array: array of SymPy expressions (NumPy object array, or
        anything ``np.asarray`` accepts, such as a SymPy Matrix), of any rank.
    :param initial_indent: prefix prepended to every generated line.
    :param output_name: name of the output array parameter.
    :returns: ``(replacements, reduced_arr)``.
        - ``replacements``: the ``(symbol, expression)`` assignments for the
          common subexpressions.
        - ``reduced_arr``: an object array shaped like ``sp_array`` holding
          the expressions written in terms of those symbols.
    """
    inputs = list(inputs)
    arr = np.asarray(sp_array, dtype=object)
    flat = list(arr.flat)

    # Names the generated temporaries must not collide with.
    used = {str(s) for s in inputs}
    used.update((name, output_name, "math", "numba"))
    for e in flat:
        if isinstance(e, sympy.Basic):
            used.update(str(s) for s in e.free_symbols)

    # A single cse pass with renaming. The former "symbols(...)" + second
    # cse pass crashed for 0 or 1 subexpressions and reused one symbol for
    # distinct subexpressions whose names happened to coincide.
    replacements, reduced_exprs = _readable_cse(flat, used)
    reduced_arr = np.array(reduced_exprs, dtype=object).reshape(arr.shape)

    body_indent = initial_indent + (" " * 4)
    params = ", ".join(map(str, inputs + [output_name]))
    lines = [
        initial_indent + "@numba.jit",
        initial_indent + "def %s(%s):" % (name, params),
    ]
    lines.extend(
        body_indent + "%s = %s" % (sym, sympy.pycode(expr))
        for sym, expr in replacements
    )
    for idx in np.ndindex(*reduced_arr.shape):
        # A 0-d output is written through the empty-tuple index.
        index = ",".join(map(str, idx)) or "()"
        lines.append(
            body_indent
            + "%s[%s] = %s" % (output_name, index, sympy.pycode(reduced_arr[idx]))
        )
    if len(lines) == 2:
        # An empty sp_array would otherwise produce a function without a body.
        lines.append(body_indent + "pass")
    print("\n".join(lines))
    return replacements, reduced_arr
# Build the symbolic roll/pitch/yaw rotation and print the generated
# numba-jitted function for it (named "numba_" + the source function name).
sp_sym = sympy.symbols('roll pitch yaw')
sp_fun = sympy_rollpitchyaw(*sp_sym)
sp_to_numba(
    "numba_" + sympy_rollpitchyaw.__name__,
    inputs=sp_sym,
    sp_array=sp_fun,
)
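# Output printed by the sp_to_numba call above (the generated code calls math.cos / math.sin, so `math` must be imported where it is used):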
@numba.jit
def numba_sympy_rollpitchyaw(roll, pitch, yaw, outmatrix):
    """Write the roll/pitch/yaw homogeneous rotation into ``outmatrix``.

    Generated by ``sp_to_numba`` from ``sympy_rollpitchyaw``. Entries
    ``[0..3, 0..3]`` of ``outmatrix`` are overwritten in place and nothing
    is returned. Assumes ``outmatrix`` is a writable array of at least 4x4,
    presumably float; confirm with callers.

    Relies on the module-level ``import math``; without it numba cannot
    resolve ``math`` and the first call fails.
    """
    # Common subexpressions hoisted by sympy.cse.
    cos_pitch = math.cos(pitch)
    cos_yaw = math.cos(yaw)
    sin_yaw = math.sin(yaw)
    sin_pitch = math.sin(pitch)
    cos_roll = math.cos(roll)
    x2_x4 = cos_roll*sin_yaw
    sin_roll = math.sin(roll)
    x1_x6 = cos_yaw*sin_roll
    x1_x4 = cos_roll*cos_yaw
    x2_x6 = sin_roll*sin_yaw
    outmatrix[0,0] = cos_pitch*cos_yaw
    outmatrix[0,1] = -cos_pitch*sin_yaw
    outmatrix[0,2] = sin_pitch
    outmatrix[0,3] = 0
    outmatrix[1,0] = sin_pitch*x1_x6 + x2_x4
    outmatrix[1,1] = -sin_pitch*x2_x6 + x1_x4
    outmatrix[1,2] = -cos_pitch*sin_roll
    outmatrix[1,3] = 0
    outmatrix[2,0] = -sin_pitch*x1_x4 + x2_x6
    outmatrix[2,1] = sin_pitch*x2_x4 + x1_x6
    outmatrix[2,2] = cos_pitch*cos_roll
    outmatrix[2,3] = 0
    # Last row of the homogeneous transform is constant.
    outmatrix[3,0] = 0
    outmatrix[3,1] = 0
    outmatrix[3,2] = 0
    outmatrix[3,3] = 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment