# Created September 27, 2024 03:47
# Gist: bridgesign/eae20508b602faa1eecaaef27015041a
# (GitHub: save this gist to your computer and use it in GitHub Desktop.)
# Torch printer for SymPy with full TorchScript (torch.jit) compatibility.
# GitHub notice: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it, open
# the file in an editor that reveals hidden Unicode characters. (Learn more about
# bidirectional Unicode characters in GitHub's documentation.)
# This is a modified version of the original file from the Modulus repository (NVIDIA).
# The original file can be found at:
# https://github.com/NVIDIA/modulus-sym/blob/main/modulus/sym/utils/sympy/torch_printer.py
# The original file is licensed under the Apache License, Version 2.0.
# The modified file is provided under the same license.
# The original file's disclaimer is provided below:
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Helper functions for converting sympy equations to PyTorch.
"""
import linecache
from typing import Callable, Dict, Iterable, List, Optional

import numpy as np
import sympy
import torch
from sympy import Float, Integer, fu, lambdify
def torch_lambdify(f, r: Iterable, extra_funcs: Optional[Dict] = None) -> Callable:
    """
    Generate a PyTorch function from a sympy equation.

    Parameters
    ----------
    f : Sympy Exp, float, int, bool
        The equation to convert to torch.
        If float, int, or bool this gets converted
        to a constant function of value `f`.
    r : list, dict
        The arguments of the generated function, in order. If a dict is
        given, its keys are used (in iteration order).
    extra_funcs : dict, optional
        Extra name -> function mappings for lambdify. They are listed before
        ``TORCH_SYMPY_PRINTER`` in the modules list passed to ``lambdify``
        (presumably giving them priority -- see the sympy lambdify docs).

    Returns
    -------
    torch_f : PyTorch function
        A positional function taking one argument per entry of `r`.
    """
    if extra_funcs is None:
        extra_funcs = {}
    # list() of a dict yields its keys, so both accepted forms of `r` work.
    # Named `args` rather than `vars` to avoid shadowing the builtin.
    args = list(r)
    return lambdify(args, f, [extra_funcs, TORCH_SYMPY_PRINTER])
def _where_torch(conditions, x, y):
    """Element-wise select (``torch.where``) that also accepts Python scalars.

    If ``x`` or ``y`` is a Python ``int`` or ``float``, it is first expanded to
    a tensor with the shape and device of ``conditions``, in torch's default
    floating dtype. The result is then ``torch.where(conditions, x, y)``.
    """
    # Use `.shape` (torch tensors have no TensorFlow-style `get_shape()`) and
    # allocate on `conditions.device` so GPU inputs do not hit a device mismatch.
    if isinstance(x, (int, float)):
        x = torch.full(conditions.shape, float(x), device=conditions.device)
    if isinstance(y, (int, float)):
        y = torch.full(conditions.shape, float(y), device=conditions.device)
    return torch.where(conditions, x, y)
def _heaviside_torch(x, values=0):
    """Heaviside step function used for sympy ``Heaviside``.

    Returns 1 where ``x > 0``, 0 where ``x < 0``, and ``values`` where
    ``x == 0``. NaN inputs stay NaN. ``values`` plays the role of the second
    argument of ``Heaviside(x, H0)`` / ``numpy.heaviside``. The default of 0
    keeps callers such as ``_sqrt_torch`` unchanged.
    """
    step = torch.maximum(torch.sign(x), torch.zeros(1, device=x.device))
    # Match the step's dtype/device so torch.where needs no further promotion.
    at_zero = torch.as_tensor(values, dtype=step.dtype, device=step.device)
    return torch.where(x == 0, at_zero, step)
def _sqrt_torch(x, eps: float = 1e-6):
    """Clamped square root used for sympy ``sqrt``.

    Computes ``sqrt(max(x - eps, 0) + eps)``. This equals ``sqrt(x)`` for
    ``x > eps`` and is held at ``sqrt(eps)`` for ``x <= eps``, including
    negative and ``-inf`` inputs, so neither the value nor its gradient blows
    up near zero. NaN inputs propagate.

    Args:
        x: input tensor.
        eps: clamp threshold (default ``1e-6``).
    """
    # clamp(min=0) gives the same value as `(x - eps) * Heaviside(x - eps)`
    # without allocating a zeros tensor, and avoids NaN at x = -inf
    # (where -inf * 0 would be NaN).
    return torch.sqrt(torch.clamp(x - eps, min=0.0) + eps)
# TODO: Add a TorchScript version; presumably blocked because TorchScript
# cannot compile functions with a `*args` signature.
def _or_torch(*x):
    """Element-wise logical OR of one or more operands (sympy ``Or``).

    Always returns a boolean tensor. A single operand is OR-ed with itself,
    which converts it to bool. Calling with no operands raises IndexError.
    """
    # Pair the first two operands directly (or x[0] with itself when alone), so
    # n operands cost n - 1 logical_or passes instead of n.
    result = torch.logical_or(x[0], x[1] if len(x) > 1 else x[0])
    for value in x[2:]:
        result = torch.logical_or(result, value)
    return result
# TODO: Add a TorchScript version; presumably blocked because TorchScript
# cannot compile functions with a `*args` signature.
def _and_torch(*x):
    """Element-wise logical AND of one or more operands (sympy ``And``).

    Always returns a boolean tensor. A single operand is AND-ed with itself,
    which converts it to bool. Calling with no operands raises IndexError.
    """
    # Pair the first two operands directly (or x[0] with itself when alone), so
    # n operands cost n - 1 logical_and passes instead of n.
    result = torch.logical_and(x[0], x[1] if len(x) > 1 else x[0])
    for value in x[2:]:
        result = torch.logical_and(result, value)
    return result
@torch.jit.script
def _min_jit(x: List[torch.Tensor]) -> torch.Tensor:
    """TorchScript fold of ``torch.minimum`` over a non-empty list of tensors."""
    assert len(x) > 0
    result = x[0]
    for tensor in x[1:]:
        result = torch.minimum(result, tensor)
    return result
def _min_torch(*x):
    """Element-wise minimum of tensors and Python scalars (sympy ``Min``).

    Python ``int``/``float`` arguments become 0-dim tensors (default floating
    dtype) on the device of the first tensor argument. All operands are then
    folded with ``torch.minimum``, which broadcasts shapes and applies type
    promotion. With only scalar arguments, the result is a 0-dim tensor on the
    default device.

    Raises:
        ValueError: if called with no arguments.
    """
    if len(x) == 0:
        raise ValueError("Min requires at least one argument")
    # Device of the first non-scalar operand; None selects torch's default device.
    device = next(
        (value.device for value in x if not isinstance(value, (int, float))),
        None,
    )
    operands = [
        torch.tensor(float(value), device=device)
        if isinstance(value, (int, float))
        else value
        for value in x
    ]
    # A pairwise fold avoids the extra copy made by torch.stack and, unlike
    # stacking, accepts broadcast-compatible shapes. functools.reduce is avoided
    # because TorchScript does not support it; _min_jit is the scripted
    # equivalent for a list of tensors.
    result = operands[0]
    for operand in operands[1:]:
        result = torch.minimum(result, operand)
    return result
@torch.jit.script
def _max_jit(x: List[torch.Tensor]) -> torch.Tensor:
    """TorchScript fold of ``torch.maximum`` over a non-empty list of tensors."""
    assert len(x) > 0
    result = x[0]
    for tensor in x[1:]:
        result = torch.maximum(result, tensor)
    return result
def _max_torch(*x):
    """Element-wise maximum of tensors and Python scalars (sympy ``Max``).

    Python ``int``/``float`` arguments become 0-dim tensors (default floating
    dtype) created directly on the device of the first tensor argument. All
    operands are then folded with ``torch.maximum``, which broadcasts shapes
    and applies type promotion. With only scalar arguments, the result is a
    0-dim tensor on the default device.

    Raises:
        ValueError: if called with no arguments.
    """
    if len(x) == 0:
        raise ValueError("Max requires at least one argument")
    # Device of the first non-scalar operand; None selects torch's default device.
    device = next(
        (value.device for value in x if not isinstance(value, (int, float))),
        None,
    )
    operands = [
        torch.tensor(float(value), device=device)
        if isinstance(value, (int, float))
        else value
        for value in x
    ]
    # A pairwise fold avoids the extra copy made by torch.stack and accepts
    # broadcast-compatible shapes; _max_jit is the scripted equivalent for a
    # list of tensors.
    result = operands[0]
    for operand in operands[1:]:
        result = torch.maximum(result, operand)
    return result
# Name -> implementation table passed as a module to sympy's lambdify in
# torch_lambdify; names in lambdify's generated code are looked up here.
TORCH_SYMPY_PRINTER = dict(
    # absolute value, sign and rounding
    abs=torch.abs,
    Abs=torch.abs,
    sign=torch.sign,
    ceiling=torch.ceil,
    floor=torch.floor,
    # exponentials, logarithms and roots
    log=torch.log,
    exp=torch.exp,
    sqrt=_sqrt_torch,
    # trigonometric
    cos=torch.cos,
    acos=torch.acos,
    sin=torch.sin,
    asin=torch.asin,
    tan=torch.tan,
    atan=torch.atan,
    atan2=torch.atan2,
    # hyperbolic
    cosh=torch.cosh,
    acosh=torch.acosh,
    sinh=torch.sinh,
    asinh=torch.asinh,
    tanh=torch.tanh,
    atanh=torch.atanh,
    # special functions
    erf=torch.erf,
    loggamma=torch.lgamma,
    # reductions, step and logical helpers defined in this module
    Min=_min_torch,
    Max=_max_torch,
    Heaviside=_heaviside_torch,
    logical_or=_or_torch,
    logical_and=_and_torch,
    where=_where_torch,
    # constants and complex support
    pi=np.pi,
    conjugate=torch.conj,
)
# Compile a sympy expression to torch and wrap it in a generated function that
# takes a single `Dict[str, torch.Tensor]` instead of positional arguments.
# The wrapper's source is compiled with exec() and registered in linecache,
# presumably so torch.jit.script (which reads function source) can find it.
def sympy_torch_script(
    expr: sympy.Expr,
    keys: List[str],
    extra_funcs: Optional[Dict] = None,
) -> Callable:
    """Lambdify ``expr`` and wrap it to read its arguments from a dict.

    Args:
        expr: sympy expression to compile.
        keys: argument names, in the positional order used by the lambdified
            function; the wrapper calls it as ``func(vars[k0], vars[k1], ...)``.
        extra_funcs: extra name -> function mappings forwarded to torch_lambdify.

    Returns:
        ``_wrapped(vars: Dict[str, torch.Tensor])`` returning the value of
        ``expr``. Accessing a missing key raises KeyError.
    """
    torch_expr = torch_lambdify(expr, keys, extra_funcs=extra_funcs)
    # NOTE(review): both the lambdified function and the wrapper are tagged as
    # module "torch"; this appears to be needed for TorchScript's recursive
    # compilation / name qualification of exec-generated functions -- confirm.
    torch_expr.__module__ = "torch"
    filename = f"<wrapped>-{torch_expr.__code__.co_filename}"
    funcname = "_wrapped"
    # repr() yields a correctly quoted and escaped literal, so symbol names that
    # contain quotes or backslashes cannot break the generated source.
    call_args = "".join(f"vars[{key!r}]," for key in keys)
    code = (
        f"def {funcname}(vars: Dict[str, torch.Tensor]):\n"
        f"    return func({call_args})\n"
    )
    namespace = {"Dict": Dict, "torch": torch, "func": torch_expr}
    funclocals = {}
    exec(compile(code, filename, "exec"), namespace, funclocals)
    # (size, mtime, lines, fullname): mtime=None keeps linecache.checkcache from
    # discarding this entry, since there is no file on disk behind it.
    linecache.cache[filename] = (
        len(code),
        None,
        code.splitlines(keepends=True),
        filename,
    )
    func = funclocals[funcname]
    func.__module__ = "torch"
    return func
# Running counter that gives every generated `_mapped` function its own
# pseudo-filename in linecache.
_mapped_funcs = 0


def torch_map(
    funcs: List[Callable],
) -> Callable:
    """Bundle several dict-consuming functions into one generated function.

    The generated ``_mapped(vars: Dict[str, torch.Tensor])`` calls every entry
    of ``funcs`` with the same ``vars`` and returns their results in order.
    NOTE(review): the return expression is ``(_0(vars),_1(vars),...)``, so with
    exactly one function the result is that function's bare value, not a
    1-tuple.

    Args:
        funcs: non-empty list of callables taking the shared ``vars`` dict.

    Raises:
        ValueError: if ``funcs`` is empty.
    """
    global _mapped_funcs
    if len(funcs) == 0:
        raise ValueError("No functions provided")
    filename = f"<mapped>-{_mapped_funcs}"
    _mapped_funcs += 1
    funcname = "_mapped"
    calls = ",".join(f"_{index}(vars)" for index in range(len(funcs)))
    code = f"def {funcname}(vars: Dict[str, torch.Tensor]):\n    return ({calls})\n"
    # Each callable is exposed to the generated source as `_0`, `_1`, ...
    namespace = {f"_{index}": fn for index, fn in enumerate(funcs)}
    namespace["Dict"] = Dict
    namespace["torch"] = torch
    scope = {}
    exec(compile(code, filename, "exec"), namespace, scope)
    # Register the source so tools that read function source (e.g. TorchScript)
    # can retrieve it; mtime=None marks it as not backed by a real file.
    linecache.cache[filename] = (
        len(code),
        None,
        code.splitlines(keepends=True),
        filename,
    )
    mapped = scope[funcname]
    mapped.__module__ = "torch"
    return mapped
def sympy_torch_map(
    exprs: List[sympy.Expr],
    extra_funcs: Optional[Dict] = None,
) -> torch.jit.ScriptFunction:
    """
    Compile a list of sympy expressions into one TorchScript function.

    Each expression is lambdified against its own free symbols, sorted by
    name. It is wrapped to read those symbols from a shared
    ``Dict[str, torch.Tensor]``. The wrappers are bundled by ``torch_map`` and
    the bundle is passed to ``torch.jit.script``.

    Args:
        exprs: non-empty list of sympy expressions to compile.
        extra_funcs: extra name -> function mappings forwarded to lambdify.

    Returns:
        A scripted function taking ``vars: Dict[str, torch.Tensor]``. It
        returns the expression values in the order of ``exprs``, as a tuple
        when more than one expression is given.

    Raises:
        ValueError: if ``exprs`` is empty (raised by ``torch_map``).
    """
    nodes = []
    for expr in exprs:
        # Sorting gives each wrapper a deterministic positional argument order.
        keys = sorted(symbol.name for symbol in expr.free_symbols)
        nodes.append(sympy_torch_script(expr, keys, extra_funcs))
    return torch.jit.script(torch_map(nodes))
if __name__ == '__main__':
    # Demo: compile three expressions over the symbols x and y, then evaluate
    # them for x = 1 and y = 2.
    x, y = sympy.symbols('x y')
    demo_exprs = [x**2 + y**2, x**2, y**2]
    compiled = sympy_torch_map(demo_exprs)
    inputs = {'x': torch.tensor(1.), 'y': torch.tensor(2.)}
    print(compiled(inputs))
# GitHub page footer: Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.