@bridgesign
Created September 27, 2024 03:47
Torch Printer for sympy with full torch jit compatibility
# This is a modified version of the original file from the Modulus repository (NVIDIA)
# The original file can be found at:
# https://github.com/NVIDIA/modulus-sym/blob/main/modulus/sym/utils/sympy/torch_printer.py
# The original file is licensed under the Apache License, Version 2.0
# The modified file is provided under the same license
# The original file disclaimer is provided below:
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helper functions for converting sympy equations to pytorch
"""
from sympy import lambdify
import sympy
import torch
import numpy as np
from typing import Callable, List, Dict, Optional
import linecache
def torch_lambdify(f, r, extra_funcs=None):
"""
generates a PyTorch function from a sympy equation
Parameters
----------
f : Sympy Exp, float, int, bool
the equation to convert to torch.
If float, int, or bool this gets converted
to a constant function of value `f`.
r : list, dict
A list of the arguments for `f`. If dict then
the keys of the dict are used.
extra_funcs : dict
A dictionary of extra functions to include in the lambdify
Returns
-------
torch_f : PyTorch function
"""
if extra_funcs is None:
extra_funcs = {}
# try:
# f = float(f)
# except:
# pass
# if isinstance(f, (float, int, bool)): # constant function
# def loop_lambda(constant):
# return constant
# lambdify_f = loop_lambda(f)
# else:
vars = [k for k in r]
lambdify_f = lambdify(vars, f, [extra_funcs, TORCH_SYMPY_PRINTER])
return lambdify_f
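# Usage sketch (not part of the original file; symbols and values are
# illustrative only). "sin" resolves to torch.sin via TORCH_SYMPY_PRINTER
# below, so the returned callable operates on tensors:
#
#     x, y = sympy.symbols("x y")
#     f = torch_lambdify(x * sympy.sin(y), ["x", "y"])
#     f(torch.tensor(2.0), torch.tensor(3.0))  # -> 2 * sin(3) as a tensor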
def _where_torch(conditions, x, y):
    # Broadcast plain Python numbers to tensors matching `conditions`
    if isinstance(x, (int, float)):
        x = float(x) * torch.ones(conditions.shape, device=conditions.device)
    if isinstance(y, (int, float)):
        y = float(y) * torch.ones(conditions.shape, device=conditions.device)
    return torch.where(conditions, x, y)
def _heaviside_torch(x, values=0):
    # `values` absorbs the optional second argument sympy's Heaviside may pass; it is unused here
    return torch.maximum(torch.sign(x), torch.zeros(1, device=x.device))
def _sqrt_torch(x):
    # Clamp the argument to at least 1e-6 so sqrt and its gradient stay finite
    return torch.sqrt((x - 1e-6) * _heaviside_torch(x - 1e-6) + 1e-6)
# TODO: Add jit version here
def _or_torch(*x):
    return_value = x[0]
    for value in x:
        return_value = torch.logical_or(return_value, value)
    return return_value
# TODO: Add jit version here
def _and_torch(*x):
    return_value = x[0]
    for value in x:
        return_value = torch.logical_and(return_value, value)
    return return_value
@torch.jit.script
def _min_jit(x: List[torch.Tensor]):
    assert len(x) > 0
    min_tensor = x[0]
    for i in range(1, len(x)):
        min_tensor = torch.minimum(min_tensor, x[i])
    return min_tensor
def _min_torch(*x):
    # get tensor shape and device from a tensor argument
    for value in x:
        if not isinstance(value, (int, float)):
            tensor_shape = list(map(int, value.shape))
            device = value.device
    # convert all floats and ints to tensors
    x_only_tensors = []
    for value in x:
        if isinstance(value, (int, float)):
            value = torch.zeros(tensor_shape, device=device) + value
        x_only_tensors.append(value)
    # reduce min
    min_tensor, _ = torch.min(torch.stack(x_only_tensors, -1), -1)
    return min_tensor
    # jit option
    # return _min_jit(x_only_tensors)
    # TODO: benchmark this other option that avoids stacking and extra memory movement
    # Update: cannot jit this because TorchScript doesn't support functools.reduce
    # return functools.reduce(torch.minimum, x)
@torch.jit.script
def _max_jit(x: List[torch.Tensor]):
    assert len(x) > 0
    max_tensor = x[0]
    for i in range(1, len(x)):
        max_tensor = torch.maximum(max_tensor, x[i])
    return max_tensor
def _max_torch(*x):
    # get tensor shape and device from a tensor argument
    for value in x:
        if not isinstance(value, (int, float)):
            tensor_shape = list(map(int, value.shape))
            device = value.device
    # convert all floats and ints to tensors
    x_only_tensors = []
    for value in x:
        if isinstance(value, (int, float)):
            value = torch.zeros(tensor_shape, device=device) + value
        x_only_tensors.append(value)
    # reduce max
    max_tensor, _ = torch.max(torch.stack(x_only_tensors, -1), -1)
    return max_tensor
TORCH_SYMPY_PRINTER = {
"abs": torch.abs,
"Abs": torch.abs,
"sign": torch.sign,
"ceiling": torch.ceil,
"floor": torch.floor,
"log": torch.log,
"exp": torch.exp,
"sqrt": _sqrt_torch,
"cos": torch.cos,
"acos": torch.acos,
"sin": torch.sin,
"asin": torch.asin,
"tan": torch.tan,
"atan": torch.atan,
"atan2": torch.atan2,
"cosh": torch.cosh,
"acosh": torch.acosh,
"sinh": torch.sinh,
"asinh": torch.asinh,
"tanh": torch.tanh,
"atanh": torch.atanh,
"erf": torch.erf,
"loggamma": torch.lgamma,
"Min": _min_torch,
"Max": _max_torch,
"Heaviside": _heaviside_torch,
"logical_or": _or_torch,
"logical_and": _and_torch,
"where": _where_torch,
"pi": np.pi,
"conjugate": torch.conj,
}
# Compile a sympy expr to torch and wrap it as a script-friendly function.
# The wrapper source is generated as text, compiled with exec, and registered
# in linecache so that torch.jit.script can retrieve it later. This hack
# repackages the lambdified function to take a Dict[str, torch.Tensor]
# instead of positional arguments (an illustration of the generated source
# follows the function below).
def sympy_torch_script(
    expr: sympy.Expr,
    keys: List[str],
    extra_funcs: Optional[Dict] = None,
) -> Callable:
    torch_expr = torch_lambdify(expr, keys, extra_funcs=extra_funcs)
    torch_expr.__module__ = "torch"
    filename = '<wrapped>-%s' % torch_expr.__code__.co_filename
    funclocals = {}
    namespace = {"Dict": Dict, "torch": torch, "func": torch_expr}
    funcname = "_wrapped"
    code = 'def %s(vars: Dict[str, torch.Tensor]):\n' % funcname
    code += ' return func('
    for key in keys:
        code += 'vars["%s"],' % key
    code += ')\n'
    c = compile(code, filename, 'exec')
    exec(c, namespace, funclocals)
    linecache.cache[filename] = (
        len(code),
        None,
        code.splitlines(keepends=True),
        filename,
    )
    func = funclocals[funcname]
    func.__module__ = "torch"
    return func
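# Illustration (derived from the code above, not present in the original
# file): for keys ["x", "y"] the generated and exec'd source is
#
#     def _wrapped(vars: Dict[str, torch.Tensor]):
#      return func(vars["x"],vars["y"],)
#
# where `func` is the lambdified expression captured in `namespace`. The
# wrapper gives TorchScript a fixed Dict[str, torch.Tensor] signature and
# unpacks it into the positional arguments that the lambdified function
# expects.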
_mapped_funcs = 0
def torch_map(
    funcs: List[Callable],
) -> Callable:
    """
    Combine a list of functions into a single function (suitable for
    torch.jit.script) that evaluates all of them on the same input dict and
    returns the results as a tuple.
    (See the illustration just below this function.)

    Args:
        funcs: list of functions to combine
    """
    global _mapped_funcs
    if len(funcs) == 0:
        raise ValueError("No functions provided")
    filename = '<mapped>-%s' % _mapped_funcs
    _mapped_funcs += 1
    funclocals = {}
    namespace = {f"_{i}": f for i, f in enumerate(funcs)}
    namespace.update({"Dict": Dict, "torch": torch})
    funcname = "_mapped"
    code = 'def %s(vars: Dict[str, torch.Tensor]):\n' % funcname
    code += ' return ('
    code += ','.join([f"_{i}(vars)" for i in range(len(funcs))])
    code += ')\n'
    c = compile(code, filename, 'exec')
    exec(c, namespace, funclocals)
    linecache.cache[filename] = (
        len(code),
        None,
        code.splitlines(keepends=True),
        filename,
    )
    func = funclocals[funcname]
    func.__module__ = "torch"
    return func
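# Illustration (derived from the code above, not present in the original
# file): for three wrapped expressions the generated source is
#
#     def _mapped(vars: Dict[str, torch.Tensor]):
#      return (_0(vars),_1(vars),_2(vars))
#
# so the mapped function evaluates every wrapped expression on the same
# input dict and returns the results as a tuple.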
def sympy_torch_map(
    exprs: List[sympy.Expr],
    extra_funcs: Optional[Dict] = None,
) -> torch.jit.ScriptFunction:
    """
    Compile a list of sympy expressions into a single TorchScript function.

    Args:
        exprs: sympy expressions to compile
        extra_funcs: extra functions to include in the lambdify call

    Returns:
        A torch.jit.ScriptFunction that takes a Dict[str, torch.Tensor] and
        returns a tuple with one tensor per expression.
    """
    nodes = []
    for expr in exprs:
        keys = sorted([k.name for k in expr.free_symbols])
        _func = sympy_torch_script(expr, keys, extra_funcs)
        nodes.append(_func)
    return torch.jit.script(torch_map(nodes))
if __name__ == '__main__':
    import sympy as sp

    x, y = sp.symbols('x y')
    exprs = [x**2 + y**2, x**2, y**2]
    func = sympy_torch_map(exprs)
    vars = {'x': torch.tensor(1.), 'y': torch.tensor(2.)}
    print(func(vars))
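    # Small extension of the demo (illustrative values, not part of the
    # original file): the scripted function also works on batched tensors
    # and composes with autograd.
    xt = torch.linspace(0.0, 1.0, 5, requires_grad=True)
    yt = torch.linspace(1.0, 2.0, 5, requires_grad=True)
    out = func({'x': xt, 'y': yt})
    out[0].sum().backward()
    print(xt.grad)  # gradient of x**2 + y**2 w.r.t. x, i.e. 2 * x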