Skip to content

Instantly share code, notes, and snippets.

@avivajpeyi
Last active October 4, 2024 07:52
Show Gist options
  • Save avivajpeyi/b0a99b3f54b6841fb37e60af312b58a4 to your computer and use it in GitHub Desktop.
Save avivajpeyi/b0a99b3f54b6841fb37e60af312b58a4 to your computer and use it in GitHub Desktop.
wdm_transform.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyO+z+jbBgidS18Pn0bqY3aQ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/avivajpeyi/b0a99b3f54b6841fb37e60af312b58a4/wdm_transform.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Comparing JAX and Numba WDM Transforms\n",
"\n",
"\n",
"## Common functions"
],
"metadata": {
"id": "H6YcKEr5mHIV"
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "0gJu4ONiQ6sB",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b6820784-3486-4621-d6af-4376b837f747"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Running JAX on cpu\n"
]
}
],
"source": [
"\"\"\" COMMON FUNCS \"\"\"\n",
"import numpy as np\n",
"from numpy import fft\n",
"from scipy.special import betainc\n",
"from scipy.signal import chirp\n",
"from typing import List, Tuple\n",
"import jax\n",
"jax.config.update('jax_enable_x64', False)\n",
"\n",
"DEVICE = jax.devices()[0].device_kind\n",
"print(f\"Running JAX on {DEVICE}\")\n",
"\n",
"PI = np.pi\n",
"FREQ_RANGE = [20, 100]\n",
"\n",
"def phitilde_vec_norm(Nf: int, Nt: int, dt: float, d: float) -> np.ndarray:\n",
"    \"\"\"Evaluate the window on the first ``Nt // 2 + 1`` frequency samples and rescale it.\n",
"\n",
"    The samples are ``2 * pi / (Nf * Nt) * k`` for ``k = 0 .. Nt // 2``; the window\n",
"    returned by ``phitilde_vec`` is then divided by ``pi ** (-1 / 2)``.\n",
"    \"\"\"\n",
"    n_total = Nf * Nt\n",
"    sample_idx = np.arange(0, Nt // 2 + 1)\n",
"    omegas = 2 * np.pi / n_total * sample_idx\n",
"    # Ollie's normalising factor\n",
"    scale = np.pi ** (-1 / 2)\n",
"    return phitilde_vec(omegas, Nf, dt, d) / scale\n",
"\n",
"def phitilde_vec(omega: np.ndarray, Nf: int, dt: float, d: float = 4.0) -> np.ndarray:\n",
"    \"\"\"Compute the phi_tilde(omega_i) array.\n",
"\n",
"    Every entry starts at ``1 / sqrt(dOmega)`` with ``dOmega = 2 * PI / (2 * Nf * dt)``.\n",
"    Entries with ``A <= |omega| < A + B`` (``A = dOmega / 4``, ``B = 3 * dOmega / 4``)\n",
"    are additionally multiplied by ``cos(PI / 2 * nu_d)``.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``B`` is not strictly positive.\n",
"    \"\"\"\n",
"    band_width = 2 * PI * (1.0 / (2 * Nf * dt))\n",
"    flat_level = 1.0 / np.sqrt(band_width)\n",
"    A = band_width / 4\n",
"    B = 3 * band_width / 4\n",
"    if B <= 0:\n",
"        raise ValueError(\"B must be greater than 0\")\n",
"\n",
"    abs_omega = np.abs(omega)\n",
"    in_taper = (abs_omega >= A) & (abs_omega < A + B)\n",
"    phi = np.full_like(omega, flat_level)\n",
"    # Only the tapered samples are touched; everything else keeps the flat level.\n",
"    phi[in_taper] = phi[in_taper] * np.cos((PI / 2.0) * __nu_d(omega[in_taper], A, B, d))\n",
"    return phi\n",
"\n",
"def __nu_d(omega: np.ndarray, A: float, B: float, d: float = 4.0) -> np.ndarray:\n",
" \"\"\"Compute the normalized incomplete beta function.\"\"\"\n",
" x = (np.abs(omega) - A) / B\n",
" return betainc(d, d, x) / betainc(d, d, 1)\n",
"\n",
"\n",
"def simulate_data(Nf):\n",
"    \"\"\"Build the benchmark input: the one-sided FFT of a quadratic chirp and its window.\n",
"\n",
"    The chirp sweeps from FREQ_RANGE[0] to FREQ_RANGE[1], is sampled at 512 Hz and has\n",
"    ``Nf * Nf`` points (``Nt`` is set equal to ``Nf``).\n",
"\n",
"    Args:\n",
"        Nf: number of frequency bins; must be a positive power of 2.\n",
"\n",
"    Returns:\n",
"        (yf, phif): ``yf`` holds the first ``ND // 2 + 1`` FFT samples of the chirp,\n",
"        ``phif`` is the output of ``phitilde_vec_norm``.\n",
"    \"\"\"\n",
"    # Nf = 0 also satisfies `Nf & (Nf - 1) == 0`, but it produces an empty time axis\n",
"    # and then fails at t[-1], so it is rejected here as well.\n",
"    assert Nf > 0 and Nf & (Nf - 1) == 0, \"Nf must be a power of 2\"\n",
"    fs = 512\n",
"    dt = 1 / fs\n",
"    Nt = Nf\n",
"    nx = 4.0\n",
"    ND = Nt * Nf\n",
"    t = np.arange(0, ND) * dt\n",
"    y = chirp(t, f0=FREQ_RANGE[0], f1=FREQ_RANGE[1], t1=t[-1], method=\"quadratic\")\n",
"    phif = phitilde_vec_norm(Nf, Nt, dt=dt, d=nx)\n",
"    yf = fft.fft(y)[:ND // 2 + 1]\n",
"    return yf, phif\n",
"\n"
]
},
{
"cell_type": "markdown",
"source": [
"## Numba WDM transform"
],
"metadata": {
"id": "4JhI5f_A3p2V"
}
},
{
"cell_type": "code",
"source": [
"\"\"\"NUMBA VERSION\"\"\"\n",
"import numpy as np\n",
"from numba import njit\n",
"from numpy import fft\n",
"\n",
"\n",
"def transform_wavelet_freq_helper_numba(\n",
"    data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray\n",
") -> np.ndarray:\n",
"    \"\"\"Wavelet transform of frequency-domain data, one frequency band at a time.\n",
"\n",
"    For every band ``0 .. Nf`` (inclusive) the windowed samples are written into a\n",
"    reusable length-``Nt`` buffer, inverse-FFT'd, and packed into the output.\n",
"\n",
"    Returns:\n",
"        Real array of shape ``(Nt, Nf)``.\n",
"    \"\"\"\n",
"    freq_strain = data.copy()\n",
"    band_buffer = np.zeros(Nt, dtype=np.complex128)\n",
"    wave = np.zeros((Nt, Nf))  # wavelet wavepacket transform of the signal\n",
"\n",
"    for f_bin in range(Nf + 1):\n",
"        __fill_wave_1_numba(f_bin, Nt, Nf, band_buffer, freq_strain, phif)\n",
"        # The inverse FFT runs outside the jitted helpers because numba doesn't support np.fft\n",
"        band_time = fft.ifft(band_buffer, Nt)\n",
"        __fill_wave_2_numba(f_bin, band_time, wave, Nt, Nf)\n",
"\n",
"    return wave\n",
"\n",
"\n",
"@njit()\n",
"def __fill_wave_1_numba(\n",
"    f_bin: int,\n",
"    Nt: int,\n",
"    Nf: int,\n",
"    DX: np.ndarray,\n",
"    data: np.ndarray,\n",
"    phif: np.ndarray,\n",
") -> None:\n",
"    \"\"\"Fill ``DX`` in place with the windowed samples of band ``f_bin``.\n",
"\n",
"    The band is centred on ``data[f_bin * Nt // 2]``, which lands in ``DX[Nt // 2]``.\n",
"    Slot ``DX[0]`` is never written by this helper.\n",
"    \"\"\"\n",
"    half = Nt // 2\n",
"    centre = f_bin * Nt // 2\n",
"\n",
"    centre_value = phif[0] * data[centre]\n",
"    if f_bin == 0 or f_bin == Nf:\n",
"        # NOTE this term appears to be needed to recover correct constant (at least for m=0), but was previously missing\n",
"        centre_value = centre_value / 2.0\n",
"    DX[half] = centre_value\n",
"\n",
"    for offset in range(1 - half, half):\n",
"        if offset == 0:\n",
"            # centre sample was handled above\n",
"            continue\n",
"        slot = half + offset\n",
"        above_top = f_bin == Nf and offset > 0\n",
"        below_bottom = f_bin == 0 and offset < 0\n",
"        if above_top or below_bottom:\n",
"            DX[slot] = 0.0\n",
"        else:\n",
"            DX[slot] = phif[abs(offset)] * data[centre + offset]\n",
"\n",
"\n",
"@njit()\n",
"def __fill_wave_2_numba(\n",
"    f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int\n",
") -> None:\n",
"    \"\"\"Pack the inverse-FFT'd band ``DX_trans`` into column(s) of ``wave`` in place.\"\"\"\n",
"    if f_bin == 0 or f_bin == Nf:\n",
"        # half of lowest and highest frequency bin pixels are redundant, so store them in\n",
"        # even (f_bin = 0) and odd (f_bin = Nf) rows of column 0 respectively\n",
"        row_shift = 0 if f_bin == 0 else 1\n",
"        for n in range(0, Nt, 2):\n",
"            wave[n + row_shift, 0] = np.real(DX_trans[n] * np.sqrt(2))\n",
"        return\n",
"\n",
"    odd_band = f_bin % 2\n",
"    for n in range(Nt):\n",
"        if (n + f_bin) % 2:\n",
"            imag = np.imag(DX_trans[n])\n",
"            # odd bands take the imaginary part with a flipped sign\n",
"            wave[n, f_bin] = -imag if odd_band else imag\n",
"        else:\n",
"            wave[n, f_bin] = np.real(DX_trans[n])\n",
"\n",
"\n",
"\n"
],
"metadata": {
"id": "RkJzTw0t3pPw"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## JAX WDM transform"
],
"metadata": {
"id": "u9mBzg9Vmhgm"
}
},
{
"cell_type": "code",
"source": [
"\"\"\"JAX VERSION\"\"\"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from functools import partial\n",
"from jax import jit\n",
"from jax.numpy.fft import ifft\n",
"\n",
"@partial(jit, static_argnames=('Nf', 'Nt'))\n",
"def transform_wavelet_freq_helper_JAX(\n",
" data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray\n",
") -> jnp.ndarray:\n",
" wave = jnp.zeros((Nt, Nf))\n",
" f_bins = jnp.arange(Nf)\n",
"\n",
" i_base = Nt // 2\n",
" jj_base = f_bins * Nt // 2\n",
"\n",
" initial_values = jnp.where(\n",
" (f_bins == 0) | (f_bins == Nf),\n",
" phif[0] * data[f_bins * Nt // 2] / 2.0,\n",
" phif[0] * data[f_bins * Nt // 2]\n",
" )\n",
"\n",
" DX = jnp.zeros((Nf, Nt), dtype=jnp.complex64)\n",
" DX = DX.at[:, Nt // 2].set(initial_values)\n",
"\n",
" j_range = jnp.arange(1 - Nt // 2, Nt // 2)\n",
" j = jnp.abs(j_range)\n",
" i = i_base + j_range\n",
"\n",
" cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0)\n",
" cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0)\n",
" cond3 = j[None, :] == 0\n",
"\n",
" jj = jj_base[:, None] + j_range[None, :]\n",
" val = jnp.where(cond1 | cond2, 0.0, phif[j] * data[jj])\n",
" DX = DX.at[:, i].set(jnp.where(cond3, DX[:, i], val))\n",
"\n",
" # Vectorized ifft\n",
" DX_trans = ifft(DX, axis=1)\n",
"\n",
" # Vectorized __fill_wave_2_jax\n",
" n_range = jnp.arange(Nt)\n",
" cond1 = (n_range[:, None] + f_bins[None, :]) % 2 == 1\n",
" cond2 = f_bins % 2 == 1\n",
"\n",
" real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))\n",
" imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))\n",
"\n",
" wave = jnp.where(cond1, imag_part, real_part)\n",
"\n",
" # Special cases for f_bin 0 and Nf\n",
" wave = wave.at[::2, 0].set(jnp.real(DX_trans[0, ::2] * jnp.sqrt(2)))\n",
" wave = wave.at[1::2, -1].set(jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2)))\n",
"\n",
" return wave.T\n"
],
"metadata": {
"id": "Cf4wt-8dX9Xb"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## CuPy version"
],
"metadata": {
"id": "b_5FzcvgLHr2"
}
},
{
"cell_type": "code",
"source": [
"try:\n",
" import cupy as cp\n",
" cupy_available = True\n",
"except:\n",
" cupy_available = False\n",
" cp = np\n",
"\n",
"\n",
"\n",
"def transform_wavelet_freq_helper_CuPy(data: cp.ndarray, Nf: int, Nt: int, phif: cp.ndarray) -> cp.ndarray:\n",
"    \"\"\"Vectorised WDM transform of frequency-domain data on the GPU via CuPy.\n",
"\n",
"    Follows the same per-band rules as the loop-based Numba implementation:\n",
"    bands ``0 .. Nf`` inclusive are windowed, inverse-FFT'd along time, and packed so\n",
"    that band 0 fills the even rows and band ``Nf`` the odd rows of column 0.\n",
"    The FFT buffer is ``complex64``.\n",
"\n",
"    Args:\n",
"        data: one-sided spectrum; indices up to ``Nf * Nt // 2`` are read.\n",
"        Nf: number of frequency bins.\n",
"        Nt: number of time bins; assumed even.\n",
"        phif: window samples; indices ``0 .. Nt // 2 - 1`` are read.\n",
"\n",
"    Returns:\n",
"        Real array of shape ``(Nt, Nf)``, or ``None`` when CuPy is not available.\n",
"    \"\"\"\n",
"    if not cupy_available:\n",
"        return None\n",
"\n",
"    half = Nt // 2\n",
"    bands = cp.arange(Nf + 1)  # include the top band f_bin == Nf\n",
"    offsets = cp.arange(1 - half, half)  # sample offsets around each band centre\n",
"    centres = bands * Nt // 2\n",
"\n",
"    # Gather indices; clipped because the entries zeroed below would otherwise\n",
"    # point outside `data`.\n",
"    src = cp.clip(centres[:, None] + offsets[None, :], 0, data.shape[0] - 1)\n",
"    vals = phif[cp.abs(offsets)][None, :] * data[src]\n",
"\n",
"    # `offsets` is sorted, so offset 0 sits at column half - 1.\n",
"    vals[0, :half - 1] = 0  # band 0 has no samples below its centre\n",
"    vals[Nf, half:] = 0  # band Nf has no samples above its centre\n",
"    vals[0, half - 1] /= 2  # the centre sample of both edge bands is halved\n",
"    vals[Nf, half - 1] /= 2\n",
"\n",
"    # Slot 0 of every band stays zero; offsets map to slots 1 .. 2 * half - 1.\n",
"    DX = cp.zeros((Nf + 1, Nt), dtype=cp.complex64)\n",
"    DX[:, 1:2 * half] = vals\n",
"\n",
"    # Inverse FFT along time, then index as [time, band].\n",
"    spec = cp.fft.ifft(DX, axis=1).T\n",
"    spec_real = cp.real(spec)\n",
"    spec_imag = cp.imag(spec)\n",
"\n",
"    n = cp.arange(Nt)[:, None]\n",
"    m = bands[None, :]\n",
"    imag_sign = (1 - 2 * (m % 2)).astype(spec_imag.dtype)  # odd bands flip the imaginary part\n",
"    packed = cp.where((n + m) % 2 == 1, imag_sign * spec_imag, spec_real)\n",
"\n",
"    wave = packed[:, :Nf].copy()\n",
"    root2 = 2.0 ** 0.5\n",
"    # Special cases: band 0 -> even rows of column 0, band Nf -> odd rows of column 0.\n",
"    wave[::2, 0] = root2 * spec_real[::2, 0]\n",
"    wave[1::2, 0] = root2 * spec_real[::2, Nf]\n",
"\n",
"    return wave"
],
"metadata": {
"id": "EWQ8ugU6LGwe"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"# CUDA kernel: one thread per (band, slot) entry of the per-band FFT input buffer.\n",
"# The complex spectrum is passed as separate float32 real / imaginary arrays.\n",
"KERNEL_CODE = r'''\n",
"extern \"C\" __global__\n",
"void wavelet_kernel(\n",
"    const float* data_re, const float* data_im, const float* phif,\n",
"    int Nt, int Nf, float* DX_real, float* DX_imag) {\n",
"\n",
"    int f = blockIdx.x * blockDim.x + threadIdx.x; // Frequency band, 0..Nf inclusive\n",
"    int t = blockIdx.y * blockDim.y + threadIdx.y; // Slot inside the band buffer\n",
"\n",
"    if (f > Nf || t >= Nt) return;\n",
"\n",
"    int half = Nt / 2;\n",
"    int off = t - half;              // offset from the band centre\n",
"    int mag = off < 0 ? -off : off;\n",
"    float re = 0.0f;\n",
"    float im = 0.0f;\n",
"\n",
"    // Slots with |off| >= half are never filled; band 0 has no samples below its\n",
"    // centre and band Nf none above it.\n",
"    bool skip = (mag >= half) || (f == 0 && off < 0) || (f == Nf && off > 0);\n",
"    if (!skip) {\n",
"        int jj = f * Nt / 2 + off;\n",
"        float w = phif[mag];\n",
"        // The centre sample of the two edge bands is halved.\n",
"        if (off == 0 && (f == 0 || f == Nf)) w *= 0.5f;\n",
"        re = w * data_re[jj];\n",
"        im = w * data_im[jj];\n",
"    }\n",
"    DX_real[f * Nt + t] = re;\n",
"    DX_imag[f * Nt + t] = im;\n",
"}\n",
"'''\n",
"\n",
"\n",
"# `cp` and `cupy_available` come from the guarded import in the CuPy cell; importing\n",
"# cupy again here without a guard would crash the notebook on machines without CUDA.\n",
"if cupy_available:\n",
"    wavelet_kernel = cp.RawKernel(KERNEL_CODE, 'wavelet_kernel')\n",
"\n",
"\n",
"def transform_wavelet_freq_helper_CUDA(data: cp.ndarray, Nf: int, Nt: int, phif: cp.ndarray) -> cp.ndarray:\n",
"    \"\"\"WDM transform whose per-band FFT input is built by a custom CUDA kernel.\n",
"\n",
"    The kernel fills bands ``0 .. Nf`` inclusive in single precision; the inverse FFT\n",
"    and the packing into the output are done with CuPy array operations, using the\n",
"    same rules as the loop-based Numba implementation.\n",
"\n",
"    Returns:\n",
"        Real array of shape ``(Nt, Nf)``, or ``None`` when CuPy is not available.\n",
"    \"\"\"\n",
"    if not cupy_available:\n",
"        return None\n",
"\n",
"    # The kernel reads float32 buffers, so split the complex spectrum and cast.\n",
"    data_re = cp.ascontiguousarray(data.real, dtype=cp.float32)\n",
"    data_im = cp.ascontiguousarray(data.imag, dtype=cp.float32)\n",
"    window = cp.ascontiguousarray(phif, dtype=cp.float32)\n",
"\n",
"    # Prepare DX as separate real and imaginary parts (one row per band 0..Nf)\n",
"    DX_real = cp.zeros((Nf + 1, Nt), dtype=cp.float32)\n",
"    DX_imag = cp.zeros((Nf + 1, Nt), dtype=cp.float32)\n",
"\n",
"    # Launch the CUDA kernel to populate DX\n",
"    threads_per_block = (16, 16)\n",
"    blocks_per_grid = ((Nf + 1 + threads_per_block[0] - 1) // threads_per_block[0],\n",
"                       (Nt + threads_per_block[1] - 1) // threads_per_block[1])\n",
"    wavelet_kernel(\n",
"        blocks_per_grid, threads_per_block,\n",
"        (data_re, data_im, window, cp.int32(Nt), cp.int32(Nf), DX_real, DX_imag)\n",
"    )\n",
"\n",
"    # Perform inverse FFT on DX using CuPy's optimized FFT, then index as [time, band]\n",
"    DX = DX_real + 1j * DX_imag\n",
"    spec = cp.fft.ifft(DX, axis=1).T\n",
"    spec_real = cp.real(spec)\n",
"    spec_imag = cp.imag(spec)\n",
"\n",
"    n = cp.arange(Nt)[:, None]\n",
"    m = cp.arange(Nf + 1)[None, :]\n",
"    imag_sign = (1 - 2 * (m % 2)).astype(spec_imag.dtype)  # odd bands flip the imaginary part\n",
"    packed = cp.where((n + m) % 2 == 1, imag_sign * spec_imag, spec_real)\n",
"\n",
"    wave = packed[:, :Nf].copy()\n",
"    root2 = 2.0 ** 0.5\n",
"    # Special cases: band 0 -> even rows of column 0, band Nf -> odd rows of column 0.\n",
"    wave[::2, 0] = root2 * spec_real[::2, 0]\n",
"    wave[1::2, 0] = root2 * spec_real[::2, Nf]\n",
"\n",
"    return wave\n",
"\n"
],
"metadata": {
"id": "TNZzHj2zMa2f",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 859
},
"outputId": "f2ebf17a-1bed-40cc-eb53-1d85cb7d7a43"
},
"execution_count": 5,
"outputs": [
{
"output_type": "error",
"ename": "ImportError",
"evalue": "\n================================================================\nFailed to import CuPy.\n\nIf you installed CuPy via wheels (cupy-cudaXXX or cupy-rocm-X-X), make sure that the package matches with the version of CUDA or ROCm installed.\n\nOn Linux, you may need to set LD_LIBRARY_PATH environment variable depending on how you installed CUDA/ROCm.\nOn Windows, try setting CUDA_PATH environment variable.\n\nCheck the Installation Guide for details:\n https://docs.cupy.dev/en/latest/install.html\n\nOriginal error:\n ImportError: libcuda.so.1: cannot open shared object file: No such file or directory\n================================================================\n",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/cupy/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mcupy\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0m_core\u001b[0m \u001b[0;31m# NOQA\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/cupy/_core/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mcupy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_core\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcore\u001b[0m \u001b[0;31m# NOQA\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mcupy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_core\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mfusion\u001b[0m \u001b[0;31m# NOQA\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mImportError\u001b[0m: libcuda.so.1: cannot open shared object file: No such file or directory",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-5-69da24a5b2b5>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mcupy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mcp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# Define the custom CUDA kernel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m wavelet_kernel = cp.RawKernel(r'''\n\u001b[1;32m 5\u001b[0m \u001b[0mextern\u001b[0m \u001b[0;34m\"C\"\u001b[0m \u001b[0m__global__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/cupy/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mcupy\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0m_core\u001b[0m \u001b[0;31m# NOQA\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m raise ImportError(f'''\n\u001b[0m\u001b[1;32m 20\u001b[0m \u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0m_environment\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_diagnose_import_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mImportError\u001b[0m: \n================================================================\nFailed to import CuPy.\n\nIf you installed CuPy via wheels (cupy-cudaXXX or cupy-rocm-X-X), make sure that the package matches with the version of CUDA or ROCm installed.\n\nOn Linux, you may need to set LD_LIBRARY_PATH environment variable depending on how you installed CUDA/ROCm.\nOn Windows, try setting CUDA_PATH environment variable.\n\nCheck the Installation Guide for details:\n https://docs.cupy.dev/en/latest/install.html\n\nOriginal error:\n ImportError: libcuda.so.1: cannot open shared object file: No such file or directory\n================================================================\n",
"",
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n"
],
"errorDetails": {
"actions": [
{
"action": "open_url",
"actionText": "Open Examples",
"url": "/notebooks/snippets/importing_libraries.ipynb"
}
]
}
}
]
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Run every available backend once on the same small problem\n",
"Nf = Nt = 2 ** 6\n",
"yf, phif = simulate_data(Nf)\n",
"wave = transform_wavelet_freq_helper_numba(yf, Nf, Nt, phif)\n",
"jax_yf, jax_phif = jnp.array(yf), jnp.array(phif)\n",
"wave_jax = transform_wavelet_freq_helper_JAX(jax_yf, Nf, Nt, jax_phif)\n",
"panels = [(\"Numba\", wave), (\"Jax\", wave_jax)]\n",
"\n",
"if cupy_available:\n",
"    cupy_yf, cupy_phif = cp.array(yf), cp.array(phif)\n",
"    wave_cupy = transform_wavelet_freq_helper_CuPy(cupy_yf, Nf, Nt, cupy_phif).get()\n",
"    wave_CUDA = transform_wavelet_freq_helper_CUDA(cupy_yf, Nf, Nt, cupy_phif).get()\n",
"    panels.extend([(\"Cupy\", wave_cupy), (\"CUDA\", wave_CUDA)])\n",
"nplots = len(panels)\n",
"\n",
"# One panel per backend, showing the magnitude of the rotated transform\n",
"fig, ax = plt.subplots(1, nplots, figsize=(5, 5), sharex=True, sharey=True)\n",
"for panel_ax, (title, image) in zip(ax, panels):\n",
"    panel_ax.imshow(np.abs(np.rot90(image)))\n",
"    panel_ax.set_title(title)\n",
"\n",
"for panel_ax in ax:\n",
"    panel_ax.set_xticks([])\n",
"    panel_ax.set_yticks([])\n",
"plt.tight_layout()\n",
"plt.show()"
],
"metadata": {
"id": "Tx1yiDeFh7Yl"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Timing comparisons"
],
"metadata": {
"id": "v8MIBbNHmndU"
}
},
{
"cell_type": "code",
"source": [
"%timeit transform_wavelet_freq_helper_numba(yf, Nf, Nt, phif)"
],
"metadata": {
"id": "yT8KZlsZjFZq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# JAX dispatches asynchronously: block until the result is ready so the compute is timed too.\n",
"%timeit transform_wavelet_freq_helper_JAX(jax_yf, Nf, Nt, jax_phif).block_until_ready()"
],
"metadata": {
"id": "X3ObJ_BrjIHC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# `cupy_yf` / `cupy_phif` only exist when CuPy imported successfully.\n",
"if cupy_available:\n",
"    # Synchronise so the GPU work is included, not just the kernel launches.\n",
"    %timeit transform_wavelet_freq_helper_CuPy(cupy_yf, Nf, Nt, cupy_phif); cp.cuda.Stream.null.synchronize()\n",
"else:\n",
"    print(\"CuPy is not available; skipping the CuPy benchmark.\")"
],
"metadata": {
"id": "UdHonCg3OaT9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import time\n",
"from tqdm.auto import tqdm\n",
"import pandas as pd\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def time_function(func, *args):\n",
" try:\n",
" t0 = time.time()\n",
" func(*args)\n",
" return time.time() - t0\n",
" except Exception:\n",
" return np.nan\n",
"\n",
"\n",
"def run_numba_timings(Nf=2**6, Nt=None, n_rep=10):\n",
"    \"\"\"Time the Numba transform ``n_rep`` times after one untimed warm-up call.\n",
"\n",
"    ``Nt`` falls back to ``Nf`` when it is falsy. Returns an array of ``n_rep`` runtimes.\n",
"    \"\"\"\n",
"    if not Nt:\n",
"        Nt = Nf\n",
"    yf, phif = simulate_data(Nf)\n",
"\n",
"    # Untimed first call (warm-up)\n",
"    transform_wavelet_freq_helper_numba(yf, Nf, Nt, phif)\n",
"\n",
"    samples = [\n",
"        time_function(transform_wavelet_freq_helper_numba, yf, Nf, Nt, phif)\n",
"        for _ in range(n_rep)\n",
"    ]\n",
"    return np.array(samples, dtype=float)\n",
"\n",
"def run_jax_timings(Nf=2**6, Nt=None, n_rep=10):\n",
"    \"\"\"Time the JAX transform ``n_rep`` times after one untimed warm-up call.\n",
"\n",
"    ``Nt`` falls back to ``Nf`` when it is falsy. Returns an array of ``n_rep`` runtimes.\n",
"    \"\"\"\n",
"    if not Nt:\n",
"        Nt = Nf\n",
"    yf, phif = simulate_data(Nf)\n",
"    jax_yf = jnp.array(yf)\n",
"    jax_phif = jnp.array(phif)\n",
"\n",
"    # Untimed first call (warm-up)\n",
"    transform_wavelet_freq_helper_JAX(jax_yf, Nf, Nt, jax_phif)\n",
"\n",
"    samples = [\n",
"        time_function(transform_wavelet_freq_helper_JAX, jax_yf, Nf, Nt, jax_phif)\n",
"        for _ in range(n_rep)\n",
"    ]\n",
"    return np.array(samples, dtype=float)\n",
"\n",
"def run_cupy_timings(Nf=2**6, Nt=None, n_rep=10):\n",
"    \"\"\"Time the CuPy transform ``n_rep`` times after one untimed warm-up call.\n",
"\n",
"    Returns ``n_rep`` NaNs without running anything when CuPy is unavailable or\n",
"    ``Nf >= 8192``. ``Nt`` falls back to ``Nf`` when it is falsy.\n",
"    \"\"\"\n",
"    if not Nt:\n",
"        Nt = Nf\n",
"\n",
"    too_large = Nf >= 8192  # Cupy seems to run out of mem here.\n",
"    if too_large or not cupy_available:\n",
"        return np.full(n_rep, np.nan)\n",
"\n",
"    yf, phif = simulate_data(Nf)\n",
"    cupy_yf = cp.array(yf)\n",
"    cupy_phif = cp.array(phif)\n",
"\n",
"    # Untimed first call (warm-up)\n",
"    transform_wavelet_freq_helper_CuPy(cupy_yf, Nf, Nt, cupy_phif)\n",
"\n",
"    samples = [\n",
"        time_function(transform_wavelet_freq_helper_CuPy, cupy_yf, Nf, Nt, cupy_phif)\n",
"        for _ in range(n_rep)\n",
"    ]\n",
"    return np.array(samples, dtype=float)\n",
"\n",
"def run_experiments(min_n=5, max_n=10, n_rep=10):\n",
"    \"\"\"\n",
"    Benchmark every backend over a range of problem sizes.\n",
"\n",
"    Parameters:\n",
"    - min_n: Smallest exponent; the first size is Nf = 2**min_n.\n",
"    - max_n: Exclusive upper exponent; the last size is Nf = 2**(max_n - 1).\n",
"    - n_rep: Number of timed repetitions for each size.\n",
"\n",
"    Returns:\n",
"    - pd.DataFrame: One row per repetition, with the size in \"Ns\" and one\n",
"      runtime column per backend.\n",
"    \"\"\"\n",
"    jax_label = f\"jax[{DEVICE}]_times\"\n",
"    columns = (\"Ns\", \"numba_times\", jax_label, \"cupy_times\")\n",
"    data = {name: [] for name in columns}\n",
"\n",
"    sizes = [2**exponent for exponent in range(min_n, max_n)]\n",
"    pbar = tqdm(sizes, desc=\"Running experiments\")\n",
"    for Nf in pbar:\n",
"        pbar.set_postfix({\"Nf\": Nf})\n",
"        runtimes = {\n",
"            \"numba_times\": run_numba_timings(Nf=Nf, n_rep=n_rep),\n",
"            jax_label: run_jax_timings(Nf=Nf, n_rep=n_rep),\n",
"            \"cupy_times\": run_cupy_timings(Nf=Nf, n_rep=n_rep),\n",
"        }\n",
"\n",
"        data[\"Ns\"].extend([Nf] * n_rep)\n",
"        for name, values in runtimes.items():\n",
"            data[name].extend(values)\n",
"\n",
"    return pd.DataFrame(data)\n",
"\n",
"\n",
"\n",
"# Running the experiments and collecting data\n",
"data = run_experiments(max_n=14, n_rep=5)\n",
"data.to_csv(\"runtimes.csv\", index=False)"
],
"metadata": {
"id": "rXA0cHe5gahJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"def plot_runtime_vs_size(data, column_label='numba', color='blue', ax=None):\n",
"    \"\"\"\n",
"    Draw mean runtime against problem size, with a standard-error band, for one backend.\n",
"\n",
"    Parameters:\n",
"    - data: DataFrame with an 'Ns' column and a '<column_label>_times' column.\n",
"    - column_label: Prefix of the runtime column to plot (e.g., 'numba', 'jax', 'cupy').\n",
"    - color: Color of the line and of the shaded band (default: 'blue').\n",
"    - ax: Axis to draw on. A new 4x3 figure is created when it is None.\n",
"\n",
"    Returns:\n",
"    - ax: The axis that was drawn on.\n",
"    \"\"\"\n",
"    times_col = f'{column_label}_times'\n",
"    stats = data.groupby('Ns')[times_col].agg(['mean', 'sem']).reset_index()\n",
"    exponents = np.log2(stats['Ns']).astype(int)\n",
"    lower = stats['mean'] - stats['sem']\n",
"    upper = stats['mean'] + stats['sem']\n",
"\n",
"    if ax is None:\n",
"        fig, ax = plt.subplots(figsize=(4, 3))\n",
"\n",
"    # Shaded band carries the legend entry; the line on top shows the mean\n",
"    ax.fill_between(exponents, lower, upper,\n",
"                    color=color, alpha=0.3, label=f'{column_label.capitalize()}', lw=0)\n",
"    ax.plot(exponents, stats['mean'], color=color, marker='')\n",
"\n",
"    ax.set_xlabel('$N_fxN_t$', fontsize=12)\n",
"    ax.set_ylabel('Runtime [s]', fontsize=12)\n",
"    ax.legend(fontsize=10, frameon=False)\n",
"    ax.grid(True, which=\"both\", linestyle=\"-\", alpha=0.2)\n",
"\n",
"    # Ticks sit at the exponents and are labelled as powers of two\n",
"    tick_labels = [f'$2^{{{exp}}}x2^{{{exp}}}$' for exp in exponents]\n",
"    ax.set_xticks(exponents)\n",
"    ax.set_xticklabels(tick_labels)\n",
"    ax.tick_params(axis='x', rotation=45)\n",
"    ax.set_yscale('log')\n",
"    plt.tight_layout()\n",
"    return ax\n",
"\n",
"# Prefer the saved GPU (and, if present, TPU) runs; otherwise use the local benchmark.\n",
"if os.path.exists(\"runtimes_gpu.csv\"):\n",
"    df_gpu = pd.read_csv(\"runtimes_gpu.csv\")\n",
"    df = df_gpu.copy()\n",
"    # The TPU file is optional: only merge its column when the file is actually there.\n",
"    if os.path.exists(\"runtimes_tpu.csv\"):\n",
"        df_tpu = pd.read_csv(\"runtimes_tpu.csv\")\n",
"        df['jax[TPU v2]_times'] = df_tpu['jax[TPU v2]_times']\n",
"else:\n",
"    df = pd.read_csv(\"runtimes.csv\")\n",
"\n",
"# One curve per '<backend>_times' column, all on a shared axis\n",
"fig, ax = plt.subplots(figsize=(4, 3))\n",
"labels = [col.split(\"_times\")[0] for col in df.columns.values if 'times' in col]\n",
"for idx, label in enumerate(labels):\n",
"    plot_runtime_vs_size(df, column_label=label, color=f\"C{idx}\", ax=ax)"
],
"metadata": {
"id": "bNoJIMuHFftq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
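"Note: the JAX and CuPy transforms run in single precision here (`jax_enable_x64` is disabled and the CuPy buffers are `complex64`), whereas the Numba transform uses `complex128`. This can be changed."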
],
"metadata": {
"id": "m5UPtRxf32cz"
}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment