Skip to content

Instantly share code, notes, and snippets.

@bmorris3
Created June 25, 2021 10:34
Show Gist options
  • Save bmorris3/198264ed8156299157cab0e9bce0c977 to your computer and use it in GitHub Desktop.
Save bmorris3/198264ed8156299157cab0e9bce0c977 to your computer and use it in GitHub Desktop.
Kitzmann et al. (2020) T-P profile parameterization
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "unknown-fantasy",
"metadata": {},
"source": [
"# Kitzmann et al. (2020)'s finite element approach to the T-P profile\n",
"\n",
"* Paper ref: https://arxiv.org/abs/1910.01070\n",
"* Code ref: https://github.com/exoclime/Helios-r2/blob/master/helios_src/additional/piecewise_poly.cpp"
]
},
{
"cell_type": "code",
"execution_count": 561,
"id": "great-stable",
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import pymc3 as pm\n",
"\n",
"gl_0 = [0.0]\n",
"gl_1 = [-1.0, 1.0]\n",
"gl_2 = [-1.0, 0.0, 1.0]\n",
"gl_3 = [-1.0, -0.447214, 0.447214, 1.0]\n",
"gl_4 = [-1.0, -0.654654, 0.654654, 1.0]\n",
"gl_5 = [-1.0, -0.654654, 0.0, 0.654654, 1.0]\n",
"gl_6 = [-1.0, -7.65055 -0.285232, 0.285232, 0.765055, 1.0]\n",
"\n",
"quadrature_nodes = [gl_0, gl_1, gl_2, gl_3, gl_4, gl_5, gl_6]\n",
"\n",
"\n",
"class Element(object): \n",
" def __init__(self, edges, order):\n",
" self.reference_vertices = quadrature_nodes[order]\n",
"\n",
" self.nb_dof = len(self.reference_vertices)\n",
"\n",
" self.dof_values = np.zeros(self.nb_dof)\n",
" self.dof_vertices = np.zeros(self.nb_dof)\n",
"\n",
" for i in range(0, self.nb_dof):\n",
" self.dof_vertices[i] = self.referenceElementMap(self.reference_vertices[i], edges[0], edges[1])\n",
"\n",
" def lagrangeBase(self, r, i):\n",
" l = 1\n",
"\n",
" for j in range(0, self.nb_dof):\n",
" if (i != j):\n",
" l *= ((r - self.reference_vertices[j]) / \n",
" (self.reference_vertices[i] - self.reference_vertices[j]))\n",
" return l\n",
"\n",
" def getValue(self, x):\n",
" # coordinate on the reference element\n",
" r = self.realElementMap(x, self.dof_vertices[0], self.dof_vertices[-1])\n",
" \n",
" y = 0\n",
"\n",
" for i in range(0, self.nb_dof):\n",
" y += self.dof_values[i] * self.lagrangeBase(r, i) \n",
"\n",
" return y\n",
" \n",
" # maps the coordinate value r on the reference element [-1, +1] to the real element [x_l, x_r]\n",
" def referenceElementMap(self, r, x_l, x_r): \n",
" return x_l + (1.0 + r)/2.0 * (x_r - x_l)\n",
" \n",
" # maps the coordinate value x on the real element [x_l, x_r] to the reference element [-1, +1]\n",
" def realElementMap(self, x, x_l, x_r): \n",
" return 2.0 * (x - x_l) / (x_r - x_l) - 1.0\n",
" \n",
" \n",
"class PiecewisePolynomial(object):\n",
" def __init__(self, element_number, polynomial_order, domain_boundaries, dof_values):\n",
" self.nb_elements = 0\n",
" self.nb_edges = 0\n",
" self.elements = []\n",
" log_boundaries = [np.log10(domain_boundaries[0]), np.log10(domain_boundaries[1])]\n",
"\n",
" self.nb_elements = element_number\n",
"\n",
" self.nb_edges = self.nb_elements + 1\n",
" self.order = polynomial_order\n",
" if (polynomial_order < 1): order = 1\n",
" if (polynomial_order > 6): order = 6\n",
" self.createElementGrid(log_boundaries)\n",
" self.setDOFvalues(dof_values)\n",
"\n",
" \n",
" def createElementGrid(self, domain_boundaries): \n",
" domain_size = domain_boundaries[0] - domain_boundaries[1]\n",
" element_size = domain_size / self.nb_elements\n",
"\n",
" element_edges = np.zeros(self.nb_elements+1)\n",
"\n",
" element_edges[0] = domain_boundaries[0]\n",
" element_edges[-1] = domain_boundaries[1]\n",
"\n",
"\n",
" for i in range(1, self.nb_edges-1):\n",
" element_edges[i] = element_edges[i-1] - element_size\n",
"\n",
"\n",
" self.elements = []\n",
" for i in range(0, self.nb_elements):\n",
" edges = [element_edges[i], element_edges[i+1]]\n",
" self.elements.append(Element(edges, self.order))\n",
"\n",
" self.dof_vertices = []\n",
" for i in range(0, self.nb_elements):\n",
" for j in range(0, self.elements[i].nb_dof-1):\n",
" self.dof_vertices.append(self.elements[i].dof_vertices[j])\n",
"\n",
" self.dof_vertices.append(self.elements[-1].dof_vertices[-1])\n",
" self.nb_dof = len(self.dof_vertices)\n",
" \n",
" def setDOFvalues(self, values):\n",
" if len(values) != self.nb_dof:\n",
" raise ValueError(\"Passed vector length does not correspond to the number of dof!\\n\")\n",
" \n",
" self.dof_values = values\n",
"\n",
" # set the dof values in each element\n",
" self.global_dof_index = 0\n",
"\n",
" for i in range(0, self.nb_elements):\n",
" for j in range(0, self.elements[i].nb_dof):\n",
" self.elements[i].dof_values[j] = self.dof_values[self.global_dof_index]\n",
" self.global_dof_index += 1\n",
"\n",
" self.global_dof_index -= 1 # ; //elements share a common boundary\n",
" \n",
" def getValue(self, x):\n",
" # The validity check below hasn't yet been implemented, but is preserved in comments:\n",
" # {\n",
" # //check validity range\n",
" # if (x > dof_vertices.front() || x < dof_vertices.back())\n",
" # {\n",
" # std::cout << \"Requested x value outside of domain of the polynomial!\\n\";\n",
"\n",
" # return 0.0;\n",
" # }\n",
"\n",
" # first, we check if x is a global DOF\n",
" for i in range(0, self.nb_dof):\n",
" if (self.dof_vertices[i] == x):\n",
" return self.dof_values[i]\n",
"\n",
" element_index = 0\n",
" \n",
" # if not, find the element it is in\n",
" for i in range(0, len(self.elements)):\n",
" # Assumes pressure is decreasing from beginning of array towards the end:\n",
" if (self.elements[i].dof_vertices[0] > x) and (self.elements[i].dof_vertices[-1] < x): \n",
" element_index = i\n",
" break\n",
"\n",
" # get the value from the corresponding element\n",
" return self.elements[element_index].getValue(x)"
]
},
{
"cell_type": "markdown",
"id": "average-refrigerator",
"metadata": {},
"source": [
"Usage example: "
]
},
{
"cell_type": "code",
"execution_count": 563,
"id": "gorgeous-denver",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"\n",
"for dof_vals in pm.TruncatedNormal.dist(mu=1000, sigma=500, upper=8000, lower=800).random(size=(10, 7)): \n",
" pp = PiecewisePolynomial(\n",
" element_number=3, polynomial_order=2, \n",
" domain_boundaries=[1e2, 1e-3], dof_values=np.sort(dof_vals)[::-1]\n",
" )\n",
" log_p = np.linspace(1.8, -2.8, 100)\n",
" \n",
" ax.scatter(pp.dof_values, 10**np.asarray(pp.dof_vertices))\n",
" ax.semilogy([pp.getValue(x) for x in log_p], 10**log_p, alpha=0.4)\n",
" ax.set(xlabel='Temperature', ylabel='Pressure')\n",
"ax.invert_yaxis()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "copyrighted-italic",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@bmorris3
Copy link
Author

And a JAX-friendly version:

import numpy as np
import matplotlib.pyplot as plt
import pymc3 as pm
from jax import numpy as jnp
from jax.ops import index_update

# Gauss–Lobatto(–Legendre) nodes on the reference element [-1, +1], rounded to
# 6 decimals.  Entry ``k`` of ``quadrature_nodes`` holds the ``k + 1`` nodes used
# by an element of polynomial order ``k``: both end points plus the roots of the
# derivative of the Legendre polynomial P_k.  Order 0 uses only the midpoint.
gl_0 = [0.0]
gl_1 = [-1.0, 1.0]
gl_2 = [-1.0, 0.0, 1.0]
gl_3 = [-1.0, -0.447214, 0.447214, 1.0]
# Each list must be strictly increasing and lie inside [-1, +1]; the element
# code relies on the first/last node being the element edges.
gl_4 = [-1.0, -0.654654, 0.0, 0.654654, 1.0]
gl_5 = [-1.0, -0.765055, -0.285232, 0.285232, 0.765055, 1.0]
gl_6 = [-1.0, -0.830224, -0.468849, 0.0, 0.468849, 0.830224, 1.0]

quadrature_nodes = [gl_0, gl_1, gl_2, gl_3, gl_4, gl_5, gl_6]


class Element(object):
    """One element of the piecewise polynomial.

    The element's nodes are the reference coordinates ``quadrature_nodes[order]``
    on [-1, +1], mapped linearly onto the interval from ``edges[0]`` to
    ``edges[1]``.  ``dof_values`` starts out empty; the owning
    ``PiecewisePolynomial`` fills it with one value per node, after which the
    element evaluates the Lagrange interpolant through those values.

    Parameters
    ----------
    edges : sequence of two numbers
        Real coordinates that the reference points -1 and +1 are mapped to.
    order : int
        Index into ``quadrature_nodes`` selecting the element's nodes.
    """

    def __init__(self, edges, order):
        # Reference nodes; this is the same list object stored in quadrature_nodes.
        self.reference_vertices = quadrature_nodes[order]
        self.nb_dof = len(self.reference_vertices)

        self.dof_values = []

        # Real-space position of every node, in the same order as the reference nodes.
        left_edge, right_edge = edges[0], edges[1]
        self.dof_vertices = [
            self.referenceElementMap(node, left_edge, right_edge)
            for node in self.reference_vertices
        ]

    def lagrangeBase(self, r, i):
        """Value at reference coordinate ``r`` of the Lagrange basis polynomial
        that equals 1 at node ``i`` and 0 at every other node of this element."""
        nodes = self.reference_vertices
        basis = 1
        for j in range(self.nb_dof):
            if j == i:
                continue
            basis *= (r - nodes[j]) / (nodes[i] - nodes[j])
        return basis

    def getValue(self, x):
        """Interpolated value at real coordinate ``x``.

        Only elementwise arithmetic is used, so NumPy/JAX arrays broadcast.
        """
        # The first and last nodes are the images of -1 and +1, i.e. the element ends.
        r = self.realElementMap(x, self.dof_vertices[0], self.dof_vertices[-1])

        total = 0
        for k in range(self.nb_dof):
            total += self.dof_values[k] * self.lagrangeBase(r, k)
        return total

    def referenceElementMap(self, r, x_l, x_r):
        """Map reference coordinate ``r`` in [-1, +1] onto the real interval [x_l, x_r]."""
        fraction = (1.0 + r) / 2.0
        return x_l + fraction * (x_r - x_l)

    def realElementMap(self, x, x_l, x_r):
        """Map real coordinate ``x`` in [x_l, x_r] onto the reference interval [-1, +1]."""
        numerator = 2.0 * (x - x_l)
        return numerator / (x_r - x_l) - 1.0
       
        
class PiecewisePolynomial(object):
    """Continuous piecewise-polynomial (finite-element) profile on a log10 grid.

    The interval between ``log10(domain_boundaries[0])`` and
    ``log10(domain_boundaries[1])`` is split into ``element_number`` equally
    sized elements, each carrying a Lagrange polynomial through the nodes of
    ``Element(edges, order)``.  Neighbouring elements share their common edge
    node, so with ``n`` nodes per element there are
    ``element_number * (n - 1) + 1`` degrees of freedom (DOFs).

    Parameters
    ----------
    element_number : int
        Number of elements (must be >= 1).
    polynomial_order : int
        Polynomial order inside each element; clamped to the range [1, 6].
    domain_boundaries : sequence of two positive numbers
        Domain end points in linear units; the grid is built in log10 space.
    dof_values : array-like
        One value per DOF, in the same order as ``self.dof_vertices`` (which
        starts at the ``domain_boundaries[0]`` end).

    Raises
    ------
    ValueError
        If ``element_number < 1`` or ``len(dof_values)`` does not match the
        number of DOFs.
    """

    def __init__(self, element_number, polynomial_order, domain_boundaries, dof_values):
        if element_number < 1:
            raise ValueError("element_number must be at least 1")

        log_boundaries = [jnp.log10(domain_boundaries[0]), jnp.log10(domain_boundaries[1])]

        self.nb_elements = element_number
        self.nb_edges = self.nb_elements + 1
        # quadrature_nodes only provides orders 0..6, and an order-0 element has a
        # single node (its first and last vertex coincide), so clamp to 1..6.
        self.order = min(max(polynomial_order, 1), 6)
        self.elements = []
        self.dof_vertices = []
        self.createElementGrid(log_boundaries)
        self.setDOFvalues(dof_values)

    def createElementGrid(self, domain_boundaries):
        """Build ``self.elements``, ``self.dof_vertices`` and ``self.nb_dof``.

        ``domain_boundaries`` holds the two end points already in log10 space.
        Edges step in equal increments from ``domain_boundaries[0]`` towards
        ``domain_boundaries[1]``; the final edge is set to the exact end point.
        """
        domain_size = domain_boundaries[0] - domain_boundaries[1]
        element_size = domain_size / self.nb_elements

        element_edges = [domain_boundaries[0]]
        for i in range(1, self.nb_edges - 1):
            element_edges.append(element_edges[i - 1] - element_size)
        element_edges.append(domain_boundaries[1])

        # Rebuild from scratch so that calling this again never duplicates elements.
        self.elements = [
            Element([element_edges[i], element_edges[i + 1]], self.order)
            for i in range(self.nb_elements)
        ]

        # Global DOFs: every node of each element except its last one, which is
        # the shared edge with the next element; then the very last node.
        self.dof_vertices = []
        for element in self.elements:
            self.dof_vertices.extend(element.dof_vertices[:-1])
        self.dof_vertices.append(self.elements[-1].dof_vertices[-1])
        self.nb_dof = len(self.dof_vertices)

    def setDOFvalues(self, values):
        """Distribute the global DOF values onto the elements.

        May be called repeatedly; each call replaces the previous values.

        Raises
        ------
        ValueError
            If ``len(values) != self.nb_dof``.
        """
        if len(values) != self.nb_dof:
            raise ValueError("Passed vector length does not correspond to the number of dof!\n")

        self.dof_values = values

        # Index of the first global DOF belonging to the current element.
        self.global_dof_index = 0
        for element in self.elements:
            start = self.global_dof_index
            # Assign (not append) so a second call really updates the element.
            element.dof_values = [values[start + j] for j in range(element.nb_dof)]
            # Elements share a common boundary DOF, hence the "- 1".
            self.global_dof_index = start + element.nb_dof - 1

    def __call__(self, x_vector):
        """Evaluate the profile at log10 coordinates ``x_vector``.

        Every point inside the domain (end points included) is assigned to
        exactly one element; a point lying exactly on an interior edge goes to
        the element preceding that edge in ``self.elements``.  That element's
        polynomial is evaluated there.  Points outside the domain give 0.
        Works for grids that increase or decrease along ``self.elements``.

        Returns an array with the same shape as ``x_vector``.
        """
        x = jnp.asarray(x_vector)

        first = self.elements[0].dof_vertices[0]
        last = self.elements[-1].dof_vertices[-1]
        # +1 when the coordinate decreases along self.elements, -1 when it increases.
        direction = jnp.sign(first - last)

        # Element index = number of interior edges already passed when walking
        # from `first` towards `last`.  Using a strict comparison here (and an
        # inclusive domain test below) keeps points that sit exactly on an edge.
        interior_edges = jnp.array([element.dof_vertices[0] for element in self.elements[1:]])
        passed = direction * x[..., None] < direction * interior_edges
        element_index = jnp.sum(passed, axis=-1)

        # Evaluate every element polynomial on the whole input at once (vectorised,
        # O(n_elements) traced operations), then pick the owning element per point.
        element_vals = jnp.stack([element.getValue(x) for element in self.elements], axis=-1)
        values = jnp.take_along_axis(element_vals, element_index[..., None], axis=-1)[..., 0]

        in_domain = (x >= jnp.minimum(first, last)) & (x <= jnp.maximum(first, last))
        return jnp.where(in_domain, values, 0)
    
from jax import jit


def piecewise_poly(log_p, domain_boundaries, dof_values, element_number, polynomial_order):
    """Evaluate a monotonic piecewise-polynomial profile at ``log_p``.

    The DOF values are sorted into descending order before use, so the first
    DOF (at the ``domain_boundaries[0]`` end of the grid) gets the largest value.

    Parameters
    ----------
    log_p : array-like
        log10 coordinates at which to evaluate the profile.
    domain_boundaries : sequence of two numbers
        Domain end points in linear units (converted to log10 internally).
    dof_values : array-like
        One value per DOF of the grid.
    element_number, polynomial_order : int
        Grid layout; they decide the Python-level structure of the computation.
    """
    descending_dofs = jnp.sort(dof_values)[::-1]
    profile = PiecewisePolynomial(
        element_number=element_number,
        polynomial_order=polynomial_order,
        domain_boundaries=jnp.array(domain_boundaries),
        dof_values=descending_dofs,
    )
    return profile(jnp.asarray(log_p))

# element_number and polynomial_order (positions 3 and 4) control list lengths and
# loop counts, so they must be static (hashable Python ints) under jit.
ppj = jit(piecewise_poly, static_argnums=(3, 4))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment