Last active November 13, 2024 21:22
Function differentiation chrestomathy
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c03405d3-11a0-4a66-9649-5386870aa05d",
   "metadata": {},
   "source": [
    "# Function differentiation\n",
    "\n",
    "If you have a Python function `f(x)` that evaluates the mathematical function $f$ at $x$, we would like to find a function $\\nabla f$ so that we can also evaluate its derivative, $\\nabla f(x)$.\n",
    "\n",
    "Let's look at naive, symbolic and automatic differentiation:\n",
    "\n",
    "- Naive: finite differences\n",
    "- Symbolic: SymPy\n",
    "- Automatic: dual numbers and JAX\n",
    "\n",
    "## TODO\n",
    "\n",
    "- Compute on arrays\n",
    "- Add plots\n",
    "- Add forward and reverse mode automatic differentiation, eg [see this](https://kenndanielso.github.io/mlrefined/blog_posts/3_Automatic_differentiation/3_4_AD_forward_mode.html)\n",
    "- Add [`tangent`](https://github.com/google/tangent) or whatever replaced it\n",
    "- Add PyTorch\n",
    "- Add TensorFlow"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e03fe0f4-66fc-4f28-860c-677b3dc6ae1e",
   "metadata": {},
   "source": [
    "## Examples\n",
    "\n",
    "The examples are from [the JAX docs](https://jax.readthedocs.io/en/latest/automatic-differentiation.html).\n",
    "\n",
    "First, we look at $\\tanh x$, whose derivative is famously convenient to calculate because it is expressed in terms of the function itself: $1 - \\tanh^2 x$.\n",
    "\n",
    "Second, we look at a polynomial function. The derivatives of $f(x) = x^3 + 2x^2 - 3x + 1$ are:\n",
    "\n",
    "$$\n",
    "\\begin{array}{l}\n",
    "f'(x) = 3x^2 + 4x - 3\\\\\n",
    "f''(x) = 6x + 4\\\\\n",
    "f'''(x) = 6\\\\\n",
    "f^{iv}(x) = 0\n",
    "\\end{array}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d55fe244-d57c-4d3c-8256-44cf2ffd4e60",
   "metadata": {},
   "source": [
    "## Analytic derivative"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6eeecbb3-de2b-4690-a9b5-1f12f762e6ee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9640275800758169"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "x = 2.0\n",
    "\n",
    "f = lambda x: np.tanh(x)\n",
    "f(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fd98dae-bb8c-4f23-a0eb-c02fca703efd",
   "metadata": {},
   "source": [
    "We might [look up](https://en.wikipedia.org/wiki/Hyperbolic_functions#Derivatives) that the derivative is $1 - \\tanh^2 x$, which we can evaluate using `f` itself:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4da7ae72-3227-4c8a-a73c-6e8761bb912e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.07065082485316443"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dfdx = 1 - f(x)**2\n",
    "dfdx"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e161755a-3593-4974-bc4a-c97ed01f9961",
   "metadata": {},
   "source": [
    "## Finite difference\n",
    "\n",
    "### One-sided\n",
    "\n",
    "We can approximate the analytic result with a finite difference over some small step $\\epsilon$. With $\\epsilon = 10^{-6}$, the estimate below agrees with the analytic value to about six decimal places:\n",
    "\n",
    "$$\n",
    "f'(x) \\approx \\frac{f(x + \\epsilon) - f(x)}{\\epsilon}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "30e40efe-bea7-4bcd-bb51-bc875dcf8143",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.07065075680046107"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ϵ = 1e-6\n",
    "dfdx = (f(x + ϵ) - f(x)) / ϵ\n",
    "dfdx"
   ]
  },
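  {
   "cell_type": "markdown",
   "id": "fd-one-sided-error-sketch",
   "metadata": {},
   "source": [
    "As a rough check on that accuracy, here is a quick sketch. Expanding $f(x + \\epsilon)$ as a Taylor series suggests that the leading error term of the one-sided estimate is $\\tfrac{1}{2} f''(x)\\,\\epsilon$. For $\\tanh$, $f''(x) = -2\\tanh x\\,(1 - \\tanh^2 x)$, so we can compare the observed error with that prediction at the same $x = 2$ and $\\epsilon = 10^{-6}$ (the variable names are illustrative):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd-one-sided-error-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "x = 2.0\n",
    "eps = 1e-6\n",
    "t = np.tanh(x)\n",
    "\n",
    "exact = 1 - t**2              # f'(x), the analytic derivative\n",
    "second = -2 * t * (1 - t**2)  # f''(x) for tanh\n",
    "\n",
    "estimate = (np.tanh(x + eps) - t) / eps\n",
    "\n",
    "print('observed error: ', estimate - exact)\n",
    "print('predicted error:', 0.5 * second * eps)  # leading Taylor term only"
   ]
  },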
  {
   "cell_type": "markdown",
   "id": "967d771e-b660-4aae-8db6-4c304390ca0b",
   "metadata": {},
   "source": [
    "### Two-sided\n",
    "\n",
    "We can do considerably better with a symmetric (central) difference, whose error shrinks as $\\epsilon^2$ rather than $\\epsilon$. But it's still only an estimate, and we cannot get around floating-point round-off, which limits how small $\\epsilon$ can usefully be.\n",
    "\n",
    "$$\n",
    "f'(x) \\approx \\frac{f(x + \\epsilon) - f(x - \\epsilon)}{2 \\epsilon}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "18a9ca9a-14f0-4c71-a1b8-2701d1ad3d48",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.07065082485713248"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ϵ = 1e-6\n",
    "dfdx = (f(x + ϵ) - f(x - ϵ)) / (2 * ϵ)\n",
    "dfdx"
   ]
  },
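  {
   "cell_type": "markdown",
   "id": "fd-step-size-sweep",
   "metadata": {},
   "source": [
    "To compare the two estimates over a range of step sizes, here is a quick sketch; the helper names `one_sided` and `two_sided` and the list of step sizes are illustrative choices. At larger $\\epsilon$ the central difference should be much more accurate, because its truncation error scales as $\\epsilon^2$; at very small $\\epsilon$ both should degrade as round-off in the numerator starts to dominate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd-step-size-sweep-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "x = 2.0\n",
    "exact = 1 - np.tanh(x)**2  # analytic derivative\n",
    "\n",
    "def one_sided(f, x, eps):\n",
    "    \"\"\"Forward-difference estimate of f'(x).\"\"\"\n",
    "    return (f(x + eps) - f(x)) / eps\n",
    "\n",
    "def two_sided(f, x, eps):\n",
    "    \"\"\"Central-difference estimate of f'(x).\"\"\"\n",
    "    return (f(x + eps) - f(x - eps)) / (2 * eps)\n",
    "\n",
    "print('eps      one-sided  two-sided')\n",
    "for eps in [1e-2, 1e-4, 1e-6, 1e-8, 1e-10]:\n",
    "    err1 = abs(one_sided(np.tanh, x, eps) - exact)\n",
    "    err2 = abs(two_sided(np.tanh, x, eps) - exact)\n",
    "    print(f'{eps:.0e}  {err1:.1e}    {err2:.1e}')"
   ]
  },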
  {
   "cell_type": "markdown",
   "id": "58ed4e2a-f108-41fe-94d4-71e6e8373872",
   "metadata": {},
   "source": [
    "## Symbolic differentiation with SymPy\n",
    "\n",
    "Examples inspired by [this Coursera tutorial](https://github.com/greyhatguy007/Mathematics-for-Machine-Learning-and-Data-Science-Specialization-Coursera/blob/main/C2/w1/C2_W1_Lab_1_differentiation_in_python.ipynb) by Luis Serrano."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "31001f99-5695-4706-b9eb-6a73e5c45dd4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle \\tanh{\\left(x \\right)}$"
      ],
      "text/plain": [
       "tanh(x)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sympy import tanh, symbols, diff\n",
    "from sympy.utilities.lambdify import lambdify\n",
    "\n",
    "x = symbols('x')\n",
    "f = tanh(x)\n",
    "f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "edcd8abb-3d50-47f1-9ee6-baebc8bf815e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 0.964027580075817$"
      ],
      "text/plain": [
       "0.964027580075817"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f.evalf(subs={x: 2})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bc9bcfb5-e894-4ad4-ad54-72b629b03fac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 1 - \\tanh^{2}{\\left(x \\right)}$"
      ],
      "text/plain": [
       "1 - tanh(x)**2"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dfdx = diff(f, x)\n",
    "dfdx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2f64ae31-cf46-4446-acf5-01a730a3dd58",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 0.0706508248531645$"
      ],
      "text/plain": [
       "0.0706508248531645"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dfdx.evalf(subs={x: 2})"
   ]
  },
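  {
   "cell_type": "markdown",
   "id": "sympy-lambdify-sketch",
   "metadata": {},
   "source": [
    "SymPy's `lambdify` turns a symbolic expression into an ordinary numerical function, which is handy for evaluating the derivative on arrays rather than one point at a time. A minimal sketch, assuming the NumPy backend; the name `dfdx_numeric` and the grid of points are illustrative:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sympy-lambdify-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sympy import diff, symbols, tanh\n",
    "from sympy.utilities.lambdify import lambdify\n",
    "\n",
    "x = symbols('x')\n",
    "dfdx = diff(tanh(x), x)  # 1 - tanh(x)**2\n",
    "\n",
    "# Compile the symbolic derivative into a NumPy-backed Python function.\n",
    "dfdx_numeric = lambdify(x, dfdx, 'numpy')\n",
    "\n",
    "print(dfdx_numeric(2.0))                    # a single point\n",
    "print(dfdx_numeric(np.linspace(-3, 3, 7)))  # element-wise on an array"
   ]
  },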
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ecec19b1-bd9a-434b-8768-cb66edb53efc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 2 \\left(\\tanh^{2}{\\left(x \\right)} - 1\\right) \\tanh{\\left(x \\right)}$"
      ],
      "text/plain": [
       "2*(tanh(x)**2 - 1)*tanh(x)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d2fdx = diff(f, x, 2)  # Second derivative.\n",
    "d2fdx"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b040f1e-db4e-4ffd-a35b-ae61bbdaa020",
   "metadata": {},
   "source": [
    "### Polynomial example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4a20f504-eb1b-414f-bd77-710525330b11",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle x^{3} + 2 x^{2} - 3 x + 1$"
      ],
      "text/plain": [
       "x**3 + 2*x**2 - 3*x + 1"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f = x**3 + 2*x**2 - 3*x + 1\n",
    "f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "029e92f2-4759-4952-9106-a0e17d0fc499",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 11.0$"
      ],
      "text/plain": [
       "11.0000000000000"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f.evalf(subs={x: 2})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b80fa332-5319-4a3f-8b17-d31c4c38644c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 3 x^{2} + 4 x - 3$"
      ],
      "text/plain": [
       "3*x**2 + 4*x - 3"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dfdx = diff(f, x)\n",
    "dfdx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "32b5a138-dc23-478a-a192-b7c93ae24616",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "$\\displaystyle 17.0$"
      ],
      "text/plain": [
       "17.0000000000000"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dfdx.evalf(subs={x: 2})"
   ]
  },
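  {
   "cell_type": "markdown",
   "id": "sympy-higher-derivatives-sketch",
   "metadata": {},
   "source": [
    "Passing an order to `diff` should reproduce the full list of derivatives given in the Examples section. A short sketch that also evaluates each one at $x = 2$ (the loop variable names are illustrative):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sympy-higher-derivatives-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sympy import diff, symbols\n",
    "\n",
    "x = symbols('x')\n",
    "f = x**3 + 2*x**2 - 3*x + 1\n",
    "\n",
    "# diff(f, x, n) returns the n-th derivative of f with respect to x.\n",
    "for n in range(1, 5):\n",
    "    dnf = diff(f, x, n)\n",
    "    print(n, dnf, dnf.subs(x, 2))"
   ]
  },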
  {
   "cell_type": "markdown",
   "id": "b6b995e6-3b0c-4510-b1c8-668929382ea7",
   "metadata": {},
   "source": [
    "## Automatic differentiation\n",
    "\n",
    "I learned about [dual numbers](https://en.wikipedia.org/wiki/Dual_number) and autodiff from Håvard Berland, who described the method at a company conference, in a version of [part of his PhD defense](https://www.pvv.ntnu.no/~berland/resources/autodiff-triallecture.pdf).\n",
    "\n",
    "Dual numbers are expressions of the form $a + b\\varepsilon$, where $a$ and $b$ are real numbers, and $\\varepsilon$ is a symbol taken to satisfy $\\varepsilon^2 = 0$ with $\\varepsilon\\neq 0$. Evaluating a polynomial $P$ at a dual number produces its derivative automatically, because every term of order $\\varepsilon^2$ or higher vanishes:\n",
    "\n",
    "$$ P(a + b\\varepsilon) = P(a) + bP'(a)\\varepsilon $$\n",
    "\n",
    "So if we choose $b = 1$, the coefficient of $\\varepsilon$ is exactly $P'(a)$.\n",
    "\n",
    "Dual numbers are implemented in some Python libraries, eg [`num-dual`](https://pypi.org/project/num-dual/), but it's not too hard to implement them ourselves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3c3f614e-c5bb-43ef-9b10-1cd55189ae9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dual:\n",
    "    \"\"\"A dual number real + dual*ε, where ε**2 == 0.\"\"\"\n",
    "\n",
    "    def __init__(self, real, dual):\n",
    "        self.real = real\n",
    "        self.dual = dual\n",
    "\n",
    "    def __add__(self, other):\n",
    "        # Sum rule: add the real and dual parts separately.\n",
    "        if isinstance(other, Dual):\n",
    "            return Dual(self.real + other.real, self.dual + other.dual)\n",
    "        else:\n",
    "            return Dual(self.real + other, self.dual)\n",
    "\n",
    "    __radd__ = __add__\n",
    "\n",
    "    def __mul__(self, other):\n",
    "        # Product rule: the ε**2 term vanishes.\n",
    "        if isinstance(other, Dual):\n",
    "            return Dual(self.real * other.real,\n",
    "                        self.real * other.dual + other.real * self.dual)\n",
    "        else:\n",
    "            return Dual(self.real * other, self.dual * other)\n",
    "\n",
    "    __rmul__ = __mul__\n",
    "\n",
    "    def __neg__(self):\n",
    "        return self.__mul__(-1)\n",
    "\n",
    "    def __sub__(self, other):\n",
    "        return self + -other\n",
    "\n",
    "    def __rsub__(self, other):\n",
    "        return other + -self\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f'Dual({self.real}, {self.dual})'"
   ]
  },
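  {
   "cell_type": "markdown",
   "id": "dual-arithmetic-derivation",
   "metadata": {},
   "source": [
    "The arithmetic in `__add__` and `__mul__` follows from expanding and then setting $\\varepsilon^2 = 0$:\n",
    "\n",
    "$$\n",
    "\\begin{array}{l}\n",
    "(a + b\\varepsilon) + (c + d\\varepsilon) = (a + c) + (b + d)\\varepsilon\\\\\n",
    "(a + b\\varepsilon)(c + d\\varepsilon) = ac + (ad + bc)\\varepsilon + bd\\varepsilon^2 = ac + (ad + bc)\\varepsilon\n",
    "\\end{array}\n",
    "$$\n",
    "\n",
    "The $\\varepsilon$ coefficient of the product, $ad + bc$, is the product rule, which is why `__mul__` computes `self.real * other.dual + other.real * self.dual`."
   ]
  },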
  {
   "cell_type": "markdown",
   "id": "8a5924ec-bd88-4ba1-8ada-938408688cca",
   "metadata": {},
   "source": [
    "Now we define our function, using multiplication instead of exponentiation (since we have not defined `__pow__()` in our class)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5d654e0e-62e5-4c71-a8bf-e80d4444dca8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "11"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "f = lambda x: x*x*x + 2*x*x - 3*x + 1\n",
    "\n",
    "# Evaluate with x = 2:\n",
    "f(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6d25b60-57e8-45c5-83c8-ff7366333c56",
   "metadata": {},
   "source": [
    "Now for the derivative at $x = 2$. Instead of evaluating at a real number, we evaluate at the dual number $x = 2 + 1\\varepsilon$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "64049bba-faf1-4142-8ec4-5f5cad178776",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dual(11, 17)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = Dual(2, 1)\n",
    "\n",
    "f(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f8ad44f-9c7c-47ba-8c07-03d08830dd6f",
   "metadata": {},
   "source": [
    "The real part, 11, is $f(2)$, and the dual part, 17, is the derivative $f'(2)$, matching the SymPy result."
   ]
  },
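  {
   "cell_type": "markdown",
   "id": "dual-pow-tanh-sketch",
   "metadata": {},
   "source": [
    "The same idea extends beyond `+` and `*`. As a sketch, assuming the exponent is a plain number, we can attach a power-rule `__pow__` to the `Dual` class above so the polynomial can be written with `**`, and define a helper for the first example from $\\tanh(a + b\\varepsilon) = \\tanh a + b(1 - \\tanh^2 a)\\varepsilon$. Neither is part of the original class; the names `dual_pow`, `dual_tanh` and `g` are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dual-pow-tanh-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def dual_pow(self, n):\n",
    "    \"\"\"Power rule for a Dual raised to a plain number n (sketch: n must not be a Dual).\"\"\"\n",
    "    return Dual(self.real ** n, n * self.real ** (n - 1) * self.dual)\n",
    "\n",
    "Dual.__pow__ = dual_pow  # Attach to the existing Dual class.\n",
    "\n",
    "def dual_tanh(d):\n",
    "    \"\"\"tanh(a + bε) = tanh(a) + b(1 - tanh(a)**2)ε.\"\"\"\n",
    "    t = np.tanh(d.real)\n",
    "    return Dual(t, d.dual * (1 - t**2))\n",
    "\n",
    "g = lambda x: x**3 + 2*x**2 - 3*x + 1\n",
    "\n",
    "print(g(Dual(2, 1)))              # same result as f(Dual(2, 1)) above\n",
    "print(dual_tanh(Dual(2.0, 1.0)))  # dual part is the tanh derivative at x = 2"
   ]
  },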
  {
   "cell_type": "markdown",
   "id": "a123d6a6-bcb6-4a56-b7b6-ffa9de291b4f",
   "metadata": {},
   "source": [
    "## Automatic differentiation with JAX\n",
    "\n",
    "[The JAX docs](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) are really good. Both example functions in this notebook come from there."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "e5484830-0f04-4082-a33f-b5de1e6a8b7d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(0.07065082, dtype=float32, weak_type=True)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import grad\n",
    "\n",
    "grad_tanh = grad(jnp.tanh)\n",
    "\n",
    "x = 2.0\n",
    "grad_tanh(x)"
   ]
  },
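  {
   "cell_type": "markdown",
   "id": "jax-jvp-sketch",
   "metadata": {},
   "source": [
    "`jax.grad` works in reverse mode. JAX also provides forward mode through `jax.jvp`, which is close in spirit to the dual numbers above: it carries a value and a tangent through the function together. A minimal sketch, with the tangent set to 1.0 (the analogue of choosing $b = 1$):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "jax-jvp-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# jax.jvp(fun, primals, tangents) returns (fun(*primals), directional derivative).\n",
    "value, derivative = jax.jvp(jnp.tanh, (2.0,), (1.0,))\n",
    "print(value, derivative)  # tanh(2) and its derivative, like Dual(2, 1) above"
   ]
  },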
  {
   "cell_type": "markdown",
   "id": "02bfd32f-9607-4216-a82b-f03a4f2d1a5d",
   "metadata": {},
   "source": [
    "We can get the second derivative by applying `grad` twice:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "740e07f0-1375-4c6b-884c-e90e67f06feb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(-0.13621868, dtype=float32, weak_type=True)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grad(grad(jnp.tanh))(x)"
   ]
  },
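  {
   "cell_type": "markdown",
   "id": "jax-vmap-sketch",
   "metadata": {},
   "source": [
    "`grad` requires a scalar output, so one way to evaluate these derivatives on an array of points is to vectorize them with `jax.vmap`. This sketch also checks both against the analytic expressions from the SymPy section; the grid of points and the tolerance are arbitrary choices for `float32`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "jax-vmap-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "xs = jnp.linspace(-3.0, 3.0, 13)  # arbitrary grid of points\n",
    "\n",
    "d1 = jax.vmap(jax.grad(jnp.tanh))(xs)\n",
    "d2 = jax.vmap(jax.grad(jax.grad(jnp.tanh)))(xs)\n",
    "\n",
    "t = jnp.tanh(xs)\n",
    "print(jnp.allclose(d1, 1 - t**2, atol=1e-6))            # f'(x) = 1 - tanh^2 x\n",
    "print(jnp.allclose(d2, 2 * (t**2 - 1) * t, atol=1e-6))  # f''(x) = 2(tanh^2 x - 1) tanh x"
   ]
  },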
  {
   "cell_type": "markdown",
   "id": "58898e5e-7b26-41de-8a13-ee84b3b056d5",
   "metadata": {},
   "source": [
    "For the polynomial, we can nest `jax.grad` to get successive derivatives:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "53756541-7421-46bf-9b46-46a6344a7dac",
   "metadata": {},
   "outputs": [],
   "source": [
    "f = lambda x: x**3 + 2*x**2 - 3*x + 1\n",
    "\n",
    "dfdx = jax.grad(f)\n",
    "d2fdx = jax.grad(dfdx)\n",
    "d3fdx = jax.grad(d2fdx)\n",
    "d4fdx = jax.grad(d3fdx)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d78a6b7-7f36-4b92-ab0a-95ecfd594439",
   "metadata": {},
   "source": [
    "We expect $f'(1) = 3 + 4 - 3 = 4$ here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "8442e216-4ed8-4e41-a1aa-a046a9576ae0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(4., dtype=float32, weak_type=True)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dfdx(1.0)"
   ]
  },
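  {
   "cell_type": "markdown",
   "id": "jax-higher-derivatives-sketch",
   "metadata": {},
   "source": [
    "The remaining derivatives at $x = 1$ should match the expressions in the Examples section: $f''(1) = 10$, $f'''(1) = 6$ and $f^{iv}(1) = 0$. A self-contained sketch that builds the same chain of nested `jax.grad` calls in a loop (the list name `derivs` is illustrative):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "jax-higher-derivatives-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "f = lambda x: x**3 + 2*x**2 - 3*x + 1\n",
    "\n",
    "# derivs[n] is the n-th derivative of f, built by nesting jax.grad.\n",
    "derivs = [f]\n",
    "for _ in range(4):\n",
    "    derivs.append(jax.grad(derivs[-1]))\n",
    "\n",
    "for n in range(1, 5):\n",
    "    print(n, derivs[n](1.0))"
   ]
  },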
  {
   "cell_type": "markdown",
   "id": "fe5d2392-820a-4424-b1b5-7955ef2021b6",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "© Matt Hall 2024 and various original authors linked in text, original content licensed CC BY"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py311",
   "language": "python",
   "name": "py311"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}