{
"cells": [
{
"cell_type": "markdown",
"id": "c03405d3-11a0-4a66-9649-5386870aa05d",
"metadata": {},
"source": [
"# Function differentiation\n",
"\n",
"If you have a Python function `f(x)` that evaluates the mathematical function $f$ at $x$, then we would like to find a function $\\nabla f$ so that we can evaluate $\\nabla f(x)$.\n",
"\n",
"Let's look at naive, symbolic and automatic differentiation:\n",
"\n",
"- Finite difference\n",
"- SymPy\n",
"- JAX\n",
"\n",
"## TODO\n",
"\n",
"- Compute on arrays\n",
"- Add plots\n",
"- Add forward and reverse mode automatic differentiation, eg [see this](https://kenndanielso.github.io/mlrefined/blog_posts/3_Automatic_differentiation/3_4_AD_forward_mode.html)\n",
"- Add [`tangent`](https://github.com/google/tangent) or whatever replaced it\n",
"- Add Torch\n",
"- Add Tensorflow"
]
},
{
"cell_type": "markdown",
"id": "e03fe0f4-66fc-4f28-860c-677b3dc6ae1e",
"metadata": {},
"source": [
"## Examples\n",
"\n",
"The examples are from [the Jax docs](https://jax.readthedocs.io/en/latest/automatic-differentiation.html).\n",
"\n",
"First, we look at $\\tanh x$, whose derivative is famously convenient to calculate: $1 - \\tanh^2 x$. \n",
"\n",
"Second, we look at a polynomial function. Then derivatives of $f(x) = x^3 + 2x^2 - 3x + 1$ can be represented as:\n",
"\n",
"$$\n",
"\\begin{array}{l}\n",
"f'(x) = 3x^2 + 4x -3\\\\\n",
"f''(x) = 6x + 4\\\\\n",
"f'''(x) = 6\\\\\n",
"f^{iv}(x) = 0\n",
"\\end{array}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "d55fe244-d57c-4d3c-8256-44cf2ffd4e60",
"metadata": {},
"source": [
"## Analytic derivative"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6eeecbb3-de2b-4690-a9b5-1f12f762e6ee",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9640275800758169"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"\n",
"x = 2.0\n",
"\n",
"f = lambda x: np.tanh(x)\n",
"f(x)"
]
},
{
"cell_type": "markdown",
"id": "0fd98dae-bb8c-4f23-a0eb-c02fca703efd",
"metadata": {},
"source": [
"We might [look up](https://en.wikipedia.org/wiki/Hyperbolic_functions#Derivatives) that the derivative is given by $1 - \\tanh^2 x$"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4da7ae72-3227-4c8a-a73c-6e8761bb912e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.07065082485316443"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfdx = 1 - f(x)**2\n",
"dfdx"
]
},
{
"cell_type": "markdown",
"id": "e161755a-3593-4974-bc4a-c97ed01f9961",
"metadata": {},
"source": [
"## Finite difference\n",
"\n",
"### One-sided\n",
"\n",
"We can approximate the analytic result with a finite difference, correct to a few decimal places:\n",
"\n",
"$$\n",
"f'(x) \\approx \\frac{f(x + \\epsilon) - f(x)}{\\epsilon}\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "30e40efe-bea7-4bcd-bb51-bc875dcf8143",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.07065075680046107"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ϵ = 1e-6\n",
"dfdx = (f(x + ϵ) - f(x)) / ϵ\n",
"dfdx"
]
},
{
"cell_type": "markdown",
"id": "967d771e-b660-4aae-8db6-4c304390ca0b",
"metadata": {},
"source": [
"### Two-sided\n",
"\n",
"We can improve slightly with a symmetric function, but it's still only an estimate, and we cannot get around the floating point imprecision.\n",
"\n",
"$$\n",
"f'(x) \\approx \\frac{f(x + \\epsilon) - f(x - \\epsilon)}{2 \\epsilon}\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "18a9ca9a-14f0-4c71-a1b8-2701d1ad3d48",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.07065082485713248"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ϵ = 1e-6\n",
"dfdx = (f(x + ϵ) - f(x - ϵ)) / (2 * ϵ)\n",
"dfdx"
]
},
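{
"cell_type": "markdown",
"id": "added-fd-eps-sweep-md",
"metadata": {},
"source": [
"How should we choose $\\epsilon$? A quick sketch, comparing both estimates against the analytic value, shows the trade-off: the truncation error shrinks as $\\epsilon$ decreases, until floating-point round-off takes over."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-fd-eps-sweep-code",
"metadata": {},
"outputs": [],
"source": [
"truth = 1 - f(x)**2  # Analytic derivative at x = 2.\n",
"for eps in [1e-2, 1e-4, 1e-6, 1e-8, 1e-10]:\n",
"    one_sided = (f(x + eps) - f(x)) / eps\n",
"    two_sided = (f(x + eps) - f(x - eps)) / (2 * eps)\n",
"    print(f'{eps:.0e}   one-sided error: {abs(one_sided - truth):.2e}   '\n",
"          f'two-sided error: {abs(two_sided - truth):.2e}')"
]
},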
{
"cell_type": "markdown",
"id": "58ed4e2a-f108-41fe-94d4-71e6e8373872",
"metadata": {},
"source": [
"## Symbolic differentiation with SymPy\n",
"\n",
"Examples inspired by [this Coursera tutorial](https://github.com/greyhatguy007/Mathematics-for-Machine-Learning-and-Data-Science-Specialization-Coursera/blob/main/C2/w1/C2_W1_Lab_1_differentiation_in_python.ipynb) by Luis Serrano."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "31001f99-5695-4706-b9eb-6a73e5c45dd4",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle \\tanh{\\left(x \\right)}$"
],
"text/plain": [
"tanh(x)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sympy import tanh, symbols, expand, diff, evalf\n",
"from sympy.utilities.lambdify import lambdify\n",
"\n",
"x = symbols('x')\n",
"f = tanh(x)\n",
"f"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "edcd8abb-3d50-47f1-9ee6-baebc8bf815e",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 0.964027580075817$"
],
"text/plain": [
"0.964027580075817"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f.evalf(subs={x: 2})"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bc9bcfb5-e894-4ad4-ad54-72b629b03fac",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 1 - \\tanh^{2}{\\left(x \\right)}$"
],
"text/plain": [
"1 - tanh(x)**2"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfdx = diff(f, x)\n",
"dfdx"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2f64ae31-cf46-4446-acf5-01a730a3dd58",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 0.0706508248531645$"
],
"text/plain": [
"0.0706508248531645"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfdx.evalf(subs={x: 2})"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "ecec19b1-bd9a-434b-8768-cb66edb53efc",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 2 \\left(\\tanh^{2}{\\left(x \\right)} - 1\\right) \\tanh{\\left(x \\right)}$"
],
"text/plain": [
"2*(tanh(x)**2 - 1)*tanh(x)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"d2fdx = diff(f, x, 2) # Second derivative.\n",
"d2fdx"
]
},
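{
"cell_type": "markdown",
"id": "added-lambdify-md",
"metadata": {},
"source": [
"Since `diff` gives us a symbolic expression, we can also compile it into a fast NumPy function with `lambdify` (imported above) and evaluate it on arrays. A sketch of one way to tackle the \"compute on arrays\" TODO:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-lambdify-code",
"metadata": {},
"outputs": [],
"source": [
"dfdx_fn = lambdify(x, dfdx, 'numpy')\n",
"\n",
"# Evaluate the derivative on an array; the last entry should be ≈0.0706508.\n",
"dfdx_fn(np.array([0.0, 1.0, 2.0]))"
]
},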
{
"cell_type": "markdown",
"id": "7b040f1e-db4e-4ffd-a35b-ae61bbdaa020",
"metadata": {},
"source": [
"#### Polynomial example"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "4a20f504-eb1b-414f-bd77-710525330b11",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle x^{3} + 2 x^{2} - 3 x + 1$"
],
"text/plain": [
"x**3 + 2*x**2 - 3*x + 1"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = x**3 + 2*x**2 - 3*x + 1\n",
"f"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "029e92f2-4759-4952-9106-a0e17d0fc499",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 11.0$"
],
"text/plain": [
"11.0000000000000"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f.evalf(subs={x: 2})"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "b80fa332-5319-4a3f-8b17-d31c4c38644c",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 3 x^{2} + 4 x - 3$"
],
"text/plain": [
"3*x**2 + 4*x - 3"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfdx = diff(f, x)\n",
"dfdx"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "32b5a138-dc23-478a-a192-b7c93ae24616",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$\\displaystyle 17.0$"
],
"text/plain": [
"17.0000000000000"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfdx.evalf(subs={x: 2})"
]
},
{
"cell_type": "markdown",
"id": "b6b995e6-3b0c-4510-b1c8-668929382ea7",
"metadata": {},
"source": [
"## Automatic differentiation\n",
"\n",
"I learned about [dual numbers](https://en.wikipedia.org/wiki/Dual_number) and autodiff from Håvard Berland, who described the method at a company conference, in a version of [part of his PhD defense](https://www.pvv.ntnu.no/~berland/resources/autodiff-triallecture.pdf).\n",
"\n",
"Dual numbers are expressions of the form $a + b\\varepsilon$, where $a$ and $b$ are real numbers, and $\\varepsilon$ is a symbol taken to satisfy $\\varepsilon^2 = 0$ with $\\varepsilon\\neq 0$. Evaluating a function with dual numbers produces the derivative automatically:\n",
"\n",
"$$ P(a + b\\varepsilon) = P(a) + bP'(a)\\varepsilon $$\n",
"\n",
"So we choose $b = 1$.\n",
"\n",
"Dual numbers are implemented in some Python libraries, eg [`num-dual`](https://pypi.org/project/num-dual/), but there are others. However, it's not too hard to implement them ourselves."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "3c3f614e-c5bb-43ef-9b10-1cd55189ae9f",
"metadata": {},
"outputs": [],
"source": [
"class Dual:\n",
" def __init__(self, real, dual):\n",
" self.real = real\n",
" self.dual = dual\n",
" def __add__(self, other):\n",
" if isinstance(other, Dual):\n",
" return Dual(self.real + other.real, self.dual + other.dual)\n",
" else:\n",
" return Dual(self.real + other, self.dual)\n",
" __radd__ = __add__\n",
" def __mul__(self, other):\n",
" if isinstance(other, Dual):\n",
" return Dual(self.real * other.real,\n",
" self.real * other.dual + other.real * self.dual)\n",
" else:\n",
" return Dual(self.real * other, self.dual * other)\n",
" __rmul__ = __mul__\n",
" def __neg__(self):\n",
" return self.__mul__(-1)\n",
" def __sub__(self, other):\n",
" return self + -other\n",
" def __rsub__(self, other):\n",
" return other + -self\n",
" def __repr__(self):\n",
" return f'Dual({self.real}, {self.dual})'"
]
},
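{
"cell_type": "markdown",
"id": "added-dual-check-md",
"metadata": {},
"source": [
"As a quick sanity check, $(2 + \\varepsilon)^2 = 4 + 4\\varepsilon$, and $4$ is indeed the derivative of $x^2$ at $x = 2$:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-dual-check-code",
"metadata": {},
"outputs": [],
"source": [
"d = Dual(2, 1)\n",
"d * d  # Expect Dual(4, 4)."
]
},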
{
"cell_type": "markdown",
"id": "8a5924ec-bd88-4ba1-8ada-938408688cca",
"metadata": {},
"source": [
"Now we define our function, using multiplication instead of exponentiation (since we have not defined `__pow__()` in our class)."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "5d654e0e-62e5-4c71-a8bf-e80d4444dca8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"11"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = lambda x: x*x*x + 2*x*x - 3*x + 1\n",
"\n",
"# Evaluate with x = 2:\n",
"f(2)"
]
},
{
"cell_type": "markdown",
"id": "b6d25b60-57e8-45c5-83c8-ff7366333c56",
"metadata": {},
"source": [
"Now for the derivative at $x = 2$. Instead of evaluating on a real number, we evaluate on the dual number, $x = 2 + 1\\epsilon$"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "64049bba-faf1-4142-8ec4-5f5cad178776",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dual(11, 17)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = Dual(2, 1)\n",
"\n",
"f(x)"
]
},
{
"cell_type": "markdown",
"id": "6f8ad44f-9c7c-47ba-8c07-03d08830dd6f",
"metadata": {},
"source": [
"The dual part, 17, is the derivative."
]
},
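{
"cell_type": "markdown",
"id": "added-dual-pow-md",
"metadata": {},
"source": [
"If we did want to write the polynomial with `**`, a minimal `__pow__` for non-negative integer exponents can be patched onto the class. This is just a sketch using repeated multiplication:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-dual-pow-code",
"metadata": {},
"outputs": [],
"source": [
"def _dual_pow(self, n):\n",
"    # Repeated multiplication; assumes n is a non-negative integer.\n",
"    result = Dual(1, 0)\n",
"    for _ in range(n):\n",
"        result = result * self\n",
"    return result\n",
"\n",
"Dual.__pow__ = _dual_pow\n",
"\n",
"g = lambda x: x**3 + 2*x**2 - 3*x + 1\n",
"g(Dual(2, 1))  # Expect Dual(11, 17) again."
]
},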
{
"cell_type": "markdown",
"id": "a123d6a6-bcb6-4a56-b7b6-ffa9de291b4f",
"metadata": {},
"source": [
"## Automatic differentiation with Jax\n",
"\n",
"[The Jax docs](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) are really good. All the examples in this notebook were derived from there."
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "e5484830-0f04-4082-a33f-b5de1e6a8b7d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(0.07065082, dtype=float32, weak_type=True)"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import grad\n",
"\n",
"grad_tanh = grad(jnp.tanh)\n",
"\n",
"x = 2.0\n",
"grad_tanh(x)"
]
},
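{
"cell_type": "markdown",
"id": "added-jax-analytic-md",
"metadata": {},
"source": [
"This agrees with the analytic derivative $1 - \\tanh^2 x$:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-jax-analytic-code",
"metadata": {},
"outputs": [],
"source": [
"1 - jnp.tanh(x)**2"
]
},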
{
"cell_type": "markdown",
"id": "02bfd32f-9607-4216-a82b-f03a4f2d1a5d",
"metadata": {},
"source": [
"We can get the second-order derivative:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "740e07f0-1375-4c6b-884c-e90e67f06feb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(-0.13621868, dtype=float32, weak_type=True)"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grad(grad(jnp.tanh))(x)"
]
},
{
"cell_type": "markdown",
"id": "58898e5e-7b26-41de-8a13-ee84b3b056d5",
"metadata": {},
"source": [
"For the polynomial:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "53756541-7421-46bf-9b46-46a6344a7dac",
"metadata": {},
"outputs": [],
"source": [
"f = lambda x: x**3 + 2*x**2 - 3*x + 1\n",
"\n",
"dfdx = jax.grad(f)\n",
"d2fdx = jax.grad(dfdx)\n",
"d3fdx = jax.grad(d2fdx)\n",
"d4fdx = jax.grad(d3fdx)"
]
},
{
"cell_type": "markdown",
"id": "9d78a6b7-7f36-4b92-ab0a-95ecfd594439",
"metadata": {},
"source": [
"We expect 4.0 here:"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "8442e216-4ed8-4e41-a1aa-a046a9576ae0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array(4., dtype=float32, weak_type=True)"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfdx(1.0)"
]
},
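{
"cell_type": "markdown",
"id": "added-jax-higher-md",
"metadata": {},
"source": [
"The higher derivatives line up with the symbolic results too: we expect $f''(1) = 10$, $f'''(1) = 6$, and $f^{iv}(1) = 0$:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-jax-higher-code",
"metadata": {},
"outputs": [],
"source": [
"d2fdx(1.0), d3fdx(1.0), d4fdx(1.0)"
]
},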
{
"cell_type": "markdown",
"id": "fe5d2392-820a-4424-b1b5-7955ef2021b6",
"metadata": {},
"source": [
"---\n",
"\n",
"© Matt Hall 2024 and various original authors linked in text, original content licensed CC BY"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py311",
"language": "python",
"name": "py311"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}