Last active
April 15, 2023 13:11
-
-
Save smsharma/9b17db5b2635973539d58129383c5f1f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "de1c7418-b1b2-4601-bccd-a7c6ace95481", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import jax\n", | |
"import jax.numpy as np\n", | |
"\n", | |
"import numpy as onp\n", | |
"\n", | |
"from scipy.interpolate import interp1d" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "b9d65511-b151-48c4-a873-89ec7186ab32", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def interp1d_jax(x, y, kind='linear', assume_sorted=False, fill_value=0.0):\n", | |
"    \"\"\"Build a jit-compatible linear interpolator, modelled on scipy's interp1d.\n", | |
"\n", | |
"    Parameters\n", | |
"    ----------\n", | |
"    x, y : array_like\n", | |
"        1D sample points and values of equal length (at least 2 entries).\n", | |
"    kind : str\n", | |
"        Only 'linear' is supported.\n", | |
"    assume_sorted : bool\n", | |
"        If False, the (x, y) pairs are sorted by x before use.\n", | |
"    fill_value : scalar\n", | |
"        Value returned for queries outside [min(x), max(x)]. Defaults to 0.0.\n", | |
"        Unlike scipy (which raises by default), out-of-range queries never\n", | |
"        raise, because a data-dependent raise cannot be traced by jax.jit.\n", | |
"\n", | |
"    Returns\n", | |
"    -------\n", | |
"    callable\n", | |
"        interpolate(x_new), evaluated elementwise on an array of any shape\n", | |
"        (including a scalar).\n", | |
"\n", | |
"    Raises\n", | |
"    ------\n", | |
"    NotImplementedError\n", | |
"        If kind is not 'linear'.\n", | |
"    ValueError\n", | |
"        If x or y is not 1D, their lengths differ, or fewer than 2 points are given.\n", | |
"    \"\"\"\n", | |
"    if kind != 'linear':\n", | |
"        raise NotImplementedError('Only linear interpolation is supported.')\n", | |
"\n", | |
"    # Accept lists/tuples as well as arrays, and validate shapes up front\n", | |
"    # (these checks run once at construction time, outside any jit trace).\n", | |
"    x = np.asarray(x)\n", | |
"    y = np.asarray(y)\n", | |
"    if x.ndim != 1 or y.ndim != 1:\n", | |
"        raise ValueError(\"x and y should be 1D arrays\")\n", | |
"    if x.shape[0] != y.shape[0]:\n", | |
"        raise ValueError(\"x and y should have the same length\")\n", | |
"    if x.shape[0] < 2:\n", | |
"        raise ValueError(\"x and y should have at least 2 entries\")\n", | |
"\n", | |
"    if not assume_sorted:\n", | |
"        order = np.argsort(x)\n", | |
"        x = x[order]\n", | |
"        y = y[order]\n", | |
"\n", | |
"    x_min = x[0]\n", | |
"    x_max = x[-1]\n", | |
"\n", | |
"    def interpolate(x_new):\n", | |
"        x_new = np.asarray(x_new)\n", | |
"\n", | |
"        # Left knot of the bracketing interval. Clipping keeps x_new == x_max\n", | |
"        # (and out-of-range points, masked below) on a valid interval.\n", | |
"        indices = np.searchsorted(x, x_new, side='right') - 1\n", | |
"        indices = np.clip(indices, 0, x.shape[0] - 2)\n", | |
"\n", | |
"        x0, x1 = x[indices], x[indices + 1]\n", | |
"        y0, y1 = y[indices], y[indices + 1]\n", | |
"\n", | |
"        t = (x_new - x0) / (x1 - x0)\n", | |
"        y_new = (1 - t) * y0 + t * y1\n", | |
"\n", | |
"        # Mask rather than raise, so the function stays traceable under jit.\n", | |
"        out_of_bounds = (x_new < x_min) | (x_new > x_max)\n", | |
"        return np.where(out_of_bounds, fill_value, y_new)\n", | |
"\n", | |
"    return interpolate" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "1f119e54-a929-49cf-814b-ad22ee016475", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Array([1.5000001, 1.5551683, 1.6164234, 1.6844373, 1.7599555, 1.8438063,\n", | |
" 1.9369088, 2.0402837, 2.155065 , 2.2825105, 2.4240181, 2.581139 ,\n", | |
" 2.755596 , 2.9493022, 3.1643806, 3.4031904, 3.6683497, 3.9627657,\n", | |
" 4.2896667, 4.652636 , 5.055655 , 5.503141 , 6. ], dtype=float32)" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = np.linspace(0, 10, 100)\n", | |
"y = np.linspace(1, 6, 100)\n", | |
"\n", | |
"x_test = np.logspace(0, 1, 23)\n", | |
"\n", | |
"# Evaluate the JAX interpolant at log-spaced points in [1, 10], all within the range of x\n", | |
"interp1d_jax(x, y)(x_test)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "aeebab8d-c88b-4894-9b42-28204349364d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Array(True, dtype=bool)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Compare against scipy.interpolate.interp1d on the same in-range query points\n", | |
"np.allclose(interp1d_jax(x, y)(x_test), interp1d(x, y)(x_test))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "4e84779b-9f2a-404b-8e5c-743ed0637a57", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Array(True, dtype=bool)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Check that the jit-compiled interpolator matches the eager (non-jit) result\n", | |
"np.allclose(jax.jit(interp1d_jax(x, y))(x_test), interp1d_jax(x, y)(x_test))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "1a8a6a7f-bd47-45b0-9696-ba33003e6f32", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment