@kwinkunks
Created November 8, 2021 18:47
Publishable graphics with Python
{
"cells": [
{
"cell_type": "markdown",
"id": "4a2559a5",
"metadata": {},
"source": [
"# \"Publishable graphics with MATLAB\"... in Python :P\n",
"\n",
"I'm trying to replicate [this nice tutorial by Martin Trauth](https://mres.uni-potsdam.de/index.php/2017/02/09/create-publishable-graphics-with-matlab/). \n",
"\n",
"It generates data and fits a model to produce this plot:\n",
"\n",
"<img src=\"http://141.89.112.21/wp-content/uploads/2017/02/publishable_graphcs_vs3.png\" width=600 />\n",
"\n",
"I'm going to stick somewhat closely to Martin's code, but with some exceptions:\n",
"\n",
"- I'll use [recommended best practice](https://towardsdatascience.com/stop-using-numpy-random-seed-581a9972805f) for NumPy's random number generation: a local generator from `np.random.default_rng()`, not `np.random.seed(0)`, which sets global state that any other code can change.\n",
"- I'll use `matplotlib`'s object-oriented API, because it's more flexible than the procedural MATLAB style of the PyPlot interface.\n",
"- I'm not going to put all the variables into one array, because it's not my usual practice; I tend to see arrays as homogeneous 'lumps' of data, not tables of data (I'd use `pandas` for that). This saves a lot of indexing to 'reach' into the `data` array."
]
},
{
"cell_type": "markdown",
"id": "08558328",
"metadata": {},
"source": [
"## Make some data"
]
},
{
"cell_type": "code",
"execution_count": 99,
"id": "9f427f45",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import scipy.optimize as so\n",
"\n",
"rng = np.random.default_rng(999)\n",
"\n",
"# How many data points do we want?\n",
"n = 26\n",
"\n",
"# Generate synthetic variables x, y and error e.\n",
"x = np.linspace(0.5, 3.0, n) + 0.2 * rng.standard_normal(n)\n",
"y = 3 + 0.2 * np.exp(x) + 0.5 * rng.standard_normal(n)\n",
"e = np.abs(rng.standard_normal(n))"
]
},
{
"cell_type": "markdown",
"id": "21187b6f",
"metadata": {},
"source": [
"Notice that `y` is generated here by the function \n",
"\n",
"$$ \\hat{y} = f(x) = c_0 + c_1 \\mathrm{e}^{x}$$\n",
"\n",
"where $c_0$ and $c_1$ are the coefficients of the function, equal to 3.0 and 0.2 respectively.\n",
"\n",
"This is the function we're going to try to fit. But we don't **know** the function, so I have not defined a Python function for it yet. We're pretending it's hidden &mdash; buried in the data, if you like."
]
},
{
"cell_type": "markdown",
"id": "bfe32977",
"metadata": {},
"source": [
"## Fit a model\n",
"\n",
"We'll use [`scipy.optimize.curve_fit()`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html) because it's convenient and I think it does exactly what we want.\n",
"\n",
"First we define the model as a Python function expressing the function $f$ above."
]
},
{
"cell_type": "code",
"execution_count": 100,
"id": "c4a1c9fc",
"metadata": {},
"outputs": [],
"source": [
"def model(t, c0, c1):\n",
"    \"\"\"The non-linear model we're fitting.\"\"\"\n",
"    return c0 + c1 * np.exp(t)\n",
"\n",
"# Define a domain.\n",
"t = np.linspace(x.min(), x.max(), n)\n",
"\n",
"# Unweighted solution.\n",
"(c0, c1), pcov = so.curve_fit(model, x, y, p0=(0, 0), method='lm')\n",
"y_hat_fit = model(t, c0, c1)\n",
"\n",
"# Weighted solution.\n",
"(c0, c1), pcov = so.curve_fit(model, x, y, p0=(0, 0), sigma=e, method='lm')\n",
"y_hat_fit_weighted = model(t, c0, c1)"
]
},
{
"cell_type": "markdown",
"id": "546b9e68",
"metadata": {},
"source": [
"We can take a look at the coefficients. Because the weighted fit ran last, `c0` and `c1` now hold the weighted solution; we know they should be close to 3.0 and 0.2:"
]
},
{
"cell_type": "code",
"execution_count": 101,
"id": "6093168b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(2.954089846079857, 0.21970627478563665)"
]
},
"execution_count": 101,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c0, c1"
]
},
{
"cell_type": "markdown",
"id": "88350bd6",
"metadata": {},
"source": [
"## Make a plot"
]
},
{
"cell_type": "code",
"execution_count": 102,
"id": "c2e65142",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"\n",
"params = {'elinewidth': 0.75,\n",
"          'capsize': 3,\n",
"          'markersize': 15,\n",
"          'markerfacecolor': 'C0',\n",
"          'markeredgewidth': 0.75,\n",
"          }\n",
"\n",
"ax.errorbar(x, y, yerr=e, fmt='k.', label='Data with errors', **params)\n",
"ax.plot(t, y_hat_fit, 'r', label='Nonlinear least squares')\n",
"ax.plot(t, y_hat_fit_weighted, 'gold', label='Weighted nonlinear least squares')\n",
"ax.tick_params(direction=\"in\")\n",
"ax.set_title('Comparison of unweighted and weighted fit', fontweight='bold')\n",
"ax.set_xlim(0, 3.5)\n",
"ax.set_xlabel('Depth in sediment [m]')\n",
"ax.set_ylim(0, 10)\n",
"ax.set_ylabel('Age of sediment [ka]')\n",
"ax.legend()\n",
"\n",
"plt.savefig('Figure.svg')\n",
"plt.savefig('Figure.png', dpi=300)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c8c7d1fd",
"metadata": {},
"source": [
"## Another way, using `least_squares`\n",
"\n",
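"This should reproduce the unweighted least squares solution we get from `curve_fit()`, above, as long as the residuals are evaluated at the observed `x` values. Evaluating them on the evenly spaced plotting domain `t` instead pairs each `y` with the wrong abscissa and yields a different curve."
]
},
{
"cell_type": "code",
"execution_count": 103,
"id": "fcf2e633",
"metadata": {},
"outputs": [],
"source": [
"def residual(phi, x, y):\n",
"    \"\"\"Misfit between the data and the model at the observed x.\"\"\"\n",
"    return y - model(x, phi[0], phi[1])\n",
"\n",
"# Fit against the observed x, not the plotting domain t.\n",
"solution = so.least_squares(residual, x0=(0, 0), args=(x, y), method='lm')\n",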
"\n",
"c0, c1 = solution.x\n",
"\n",
"y_hat_lsq = model(t, c0, c1)"
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "5d6a98d7",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"\n",
"params = {'elinewidth': 0.75,\n",
"          'capsize': 3,\n",
"          'markersize': 15,\n",
"          'markerfacecolor': 'C0',\n",
"          'markeredgewidth': 0.75,\n",
"          }\n",
"\n",
"ax.errorbar(x, y, yerr=e, fmt='k.', label='Data with errors', **params)\n",
"ax.plot(t, y_hat_fit, 'r', label='Nonlinear least squares')\n",
"ax.plot(t, y_hat_fit_weighted, 'gold', label='Weighted nonlinear least squares')\n",
"ax.plot(t, y_hat_lsq, 'g--', label='least_squares() method')\n",
"ax.tick_params(direction=\"in\")\n",
"ax.set_title('Comparison of unweighted and weighted fit', fontweight='bold')\n",
"ax.set_xlim(0, 3.5)\n",
"ax.set_xlabel('Depth in sediment [m]')\n",
"ax.set_ylim(0, 10)\n",
"ax.set_ylabel('Age of sediment [ka]')\n",
"ax.legend()\n",
"\n",
"plt.savefig('Figure.svg')\n",
"plt.savefig('Figure.png', dpi=300)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "5cc82d8e",
"metadata": {},
"source": [
"---\n",
"[Thanks to Martin Trauth](http://mres.uni-potsdam.de/index.php/2017/02/09/create-publishable-graphics-with-matlab/)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
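
A quick way to see how `curve_fit()` and `least_squares()` relate is to run both on the same data and compare the coefficients. The sketch below regenerates the synthetic data with the same seed and recipe as the notebook, then fits the model both ways. It assumes the residuals are evaluated at the observed `x`, and that weighting in `least_squares()` can be expressed by dividing each residual by its error `e`, which is the quantity `curve_fit()` minimizes when given `sigma`. The optional `sigma` argument on `residual()` is my own addition for illustration, not part of the notebook.

```python
import numpy as np
import scipy.optimize as so

# Same synthetic data recipe as the notebook.
rng = np.random.default_rng(999)
n = 26
x = np.linspace(0.5, 3.0, n) + 0.2 * rng.standard_normal(n)
y = 3 + 0.2 * np.exp(x) + 0.5 * rng.standard_normal(n)
e = np.abs(rng.standard_normal(n))


def model(t, c0, c1):
    """The non-linear model we're fitting."""
    return c0 + c1 * np.exp(t)


def residual(phi, x, y, sigma=None):
    """Misfit at the observed x; scaled by sigma if given (illustrative)."""
    r = y - model(x, phi[0], phi[1])
    return r if sigma is None else r / sigma


# Unweighted: both routines minimize sum((y - f(x))**2).
p_cf, _ = so.curve_fit(model, x, y, p0=(0.0, 0.0), method='lm')
p_ls = so.least_squares(residual, x0=(0.0, 0.0), args=(x, y), method='lm').x

# Weighted: both minimize sum(((y - f(x)) / e)**2).
p_cf_w, _ = so.curve_fit(model, x, y, p0=(0.0, 0.0), sigma=e, method='lm')
p_ls_w = so.least_squares(residual, x0=(0.0, 0.0), args=(x, y, e), method='lm').x

print('unweighted:', p_cf, p_ls)
print('weighted:  ', p_cf_w, p_ls_w)
```

If the two unweighted results disagree noticeably, the first thing to check is whether the residual function is being handed the data abscissa `x` or some other array of the same length, such as a plotting domain.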