Last active
October 4, 2024 07:52
-
-
Save avivajpeyi/b0a99b3f54b6841fb37e60af312b58a4 to your computer and use it in GitHub Desktop.
wdm_transform.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyO+z+jbBgidS18Pn0bqY3aQ", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/avivajpeyi/b0a99b3f54b6841fb37e60af312b58a4/wdm_transform.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Comparing JAX and Numba WDM Transforms\n", | |
"\n", | |
"\n", | |
"## Common functions" | |
], | |
"metadata": { | |
"id": "H6YcKEr5mHIV" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"id": "0gJu4ONiQ6sB", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "b6820784-3486-4621-d6af-4376b837f747" | |
}, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Running JAX on cpu\n" | |
] | |
} | |
], | |
"source": [ | |
"\"\"\" COMMON FUNCS \"\"\"\n", | |
"import numpy as np\n", | |
"from numpy import fft\n", | |
"from scipy.special import betainc\n", | |
"from scipy.signal import chirp\n", | |
"from typing import List, Tuple\n", | |
"import jax\n", | |
"jax.config.update('jax_enable_x64', False)\n", | |
"\n", | |
"DEVICE = jax.devices()[0].device_kind\n", | |
"print(f\"Running JAX on {DEVICE}\")\n", | |
"\n", | |
"PI = np.pi\n", | |
"FREQ_RANGE = [20, 100]\n", | |
"\n", | |
"def phitilde_vec_norm(Nf: int, Nt: int, dt: float, d: float) -> np.ndarray:\n", | |
" ND = Nf * Nt\n", | |
" omegas = 2 * np.pi / ND * np.arange(0, Nt // 2 + 1)\n", | |
" u_phit = phitilde_vec(omegas, Nf, dt, d)\n", | |
" normalising_factor = np.pi ** (-1 / 2) # Ollie's normalising factor\n", | |
" return u_phit / (normalising_factor)\n", | |
"\n", | |
"def phitilde_vec(omega: np.ndarray, Nf: int, dt: float, d: float = 4.0) -> np.ndarray:\n", | |
" \"\"\"Compute phi_tilde(omega_i) array.\"\"\"\n", | |
" dF = 1.0 / (2 * Nf * dt)\n", | |
" dOmega = 2 * PI * dF\n", | |
" inverse_sqrt_dOmega = 1.0 / np.sqrt(dOmega)\n", | |
" A, B = dOmega / 4, 3 * dOmega / 4\n", | |
" if B <= 0:\n", | |
" raise ValueError(\"B must be greater than 0\")\n", | |
" phi = np.full_like(omega, inverse_sqrt_dOmega)\n", | |
" mask = (A <= np.abs(omega)) & (np.abs(omega) < A + B)\n", | |
" phi[mask] *= np.cos((PI / 2.0) * __nu_d(omega[mask], A, B, d))\n", | |
" return phi\n", | |
"\n", | |
"def __nu_d(omega: np.ndarray, A: float, B: float, d: float = 4.0) -> np.ndarray:\n", | |
" \"\"\"Compute the normalized incomplete beta function.\"\"\"\n", | |
" x = (np.abs(omega) - A) / B\n", | |
" return betainc(d, d, x) / betainc(d, d, 1)\n", | |
"\n", | |
"\n", | |
"def simulate_data(Nf):\n", | |
" # assert Nf is power of 2\n", | |
" assert Nf & (Nf - 1) == 0, \"Nf must be a power of 2\"\n", | |
" fs = 512\n", | |
" dt = 1 / fs\n", | |
" Nt = Nf\n", | |
" mult = 16\n", | |
" nx = 4.0\n", | |
" ND = Nt * Nf\n", | |
" t = np.arange(0, ND) * dt\n", | |
" y = chirp(t, f0=FREQ_RANGE[0], f1=FREQ_RANGE[1], t1=t[-1], method=\"quadratic\")\n", | |
" phif = phitilde_vec_norm(Nf, Nt, dt=dt, d=nx)\n", | |
" yf = fft.fft(y)[:ND//2+1]\n", | |
" return yf, phif\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## NUMBA Transform" | |
], | |
"metadata": { | |
"id": "4JhI5f_A3p2V" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"\"\"\"NUMBA VERSION\"\"\"\n", | |
"import numpy as np\n", | |
"from numba import njit\n", | |
"from numpy import fft\n", | |
"\n", | |
"\n", | |
"def transform_wavelet_freq_helper_numba(\n", | |
" data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray\n", | |
") -> np.ndarray:\n", | |
" \"\"\"helper to do the wavelet transform using the fast wavelet domain transform\"\"\"\n", | |
" wave = np.zeros((Nt, Nf)) # wavelet wavepacket transform of the signal\n", | |
"\n", | |
" DX = np.zeros(Nt, dtype=np.complex128)\n", | |
" freq_strain = data.copy() # Convert\n", | |
" for f_bin in range(0, Nf + 1):\n", | |
" __fill_wave_1_numba(f_bin, Nt, Nf, DX, freq_strain, phif)\n", | |
" DX_trans = fft.ifft(\n", | |
" DX, Nt\n", | |
" ) # A fix because numba doesn't support np.fft\n", | |
" __fill_wave_2_numba(f_bin, DX_trans, wave, Nt, Nf)\n", | |
"\n", | |
" return wave\n", | |
"\n", | |
"\n", | |
"@njit()\n", | |
"def __fill_wave_1_numba(\n", | |
" f_bin: int,\n", | |
" Nt: int,\n", | |
" Nf: int,\n", | |
" DX: np.ndarray,\n", | |
" data: np.ndarray,\n", | |
" phif: np.ndarray,\n", | |
") -> None:\n", | |
" \"\"\"helper for assigning DX in the main loop\"\"\"\n", | |
" i_base = Nt // 2\n", | |
" jj_base = f_bin * Nt // 2\n", | |
"\n", | |
" if f_bin == 0 or f_bin == Nf:\n", | |
" # NOTE this term appears to be needed to recover correct constant (at least for m=0), but was previously missing\n", | |
" DX[Nt // 2] = phif[0] * data[f_bin * Nt // 2] / 2.0\n", | |
" else:\n", | |
" DX[Nt // 2] = phif[0] * data[f_bin * Nt // 2]\n", | |
"\n", | |
" for jj in range(jj_base + 1 - Nt // 2, jj_base + Nt // 2):\n", | |
" j = np.abs(jj - jj_base)\n", | |
" i = i_base - jj_base + jj\n", | |
" if f_bin == Nf and jj > jj_base:\n", | |
" DX[i] = 0.0\n", | |
" elif f_bin == 0 and jj < jj_base:\n", | |
" DX[i] = 0.0\n", | |
" elif j == 0:\n", | |
" continue\n", | |
" else:\n", | |
" DX[i] = phif[j] * data[jj]\n", | |
"\n", | |
"\n", | |
"@njit()\n", | |
"def __fill_wave_2_numba(\n", | |
" f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int\n", | |
") -> None:\n", | |
" if f_bin == 0:\n", | |
" # half of lowest and highest frequency bin pixels are redundant, so store them in even and odd components of f_bin=0 respectively\n", | |
" for n in range(0, Nt, 2):\n", | |
" wave[n, 0] = np.real(DX_trans[n] * np.sqrt(2))\n", | |
" elif f_bin == Nf:\n", | |
" for n in range(0, Nt, 2):\n", | |
" wave[n + 1, 0] = np.real(DX_trans[n] * np.sqrt(2))\n", | |
" else:\n", | |
" for n in range(0, Nt):\n", | |
" if f_bin % 2:\n", | |
" if (n + f_bin) % 2:\n", | |
" wave[n, f_bin] = -np.imag(DX_trans[n])\n", | |
" else:\n", | |
" wave[n, f_bin] = np.real(DX_trans[n])\n", | |
" else:\n", | |
" if (n + f_bin) % 2:\n", | |
" wave[n, f_bin] = np.imag(DX_trans[n])\n", | |
" else:\n", | |
" wave[n, f_bin] = np.real(DX_trans[n])\n", | |
"\n", | |
"\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "RkJzTw0t3pPw" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## JAX WDM transform" | |
], | |
"metadata": { | |
"id": "u9mBzg9Vmhgm" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"\"\"\"JAX VERSION\"\"\"\n", | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"from functools import partial\n", | |
"from jax import jit\n", | |
"from jax.numpy.fft import ifft\n", | |
"\n", | |
"@partial(jit, static_argnames=('Nf', 'Nt'))\n", | |
"def transform_wavelet_freq_helper_JAX(\n", | |
" data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray\n", | |
") -> jnp.ndarray:\n", | |
" wave = jnp.zeros((Nt, Nf))\n", | |
" f_bins = jnp.arange(Nf)\n", | |
"\n", | |
" i_base = Nt // 2\n", | |
" jj_base = f_bins * Nt // 2\n", | |
"\n", | |
" initial_values = jnp.where(\n", | |
" (f_bins == 0) | (f_bins == Nf),\n", | |
" phif[0] * data[f_bins * Nt // 2] / 2.0,\n", | |
" phif[0] * data[f_bins * Nt // 2]\n", | |
" )\n", | |
"\n", | |
" DX = jnp.zeros((Nf, Nt), dtype=jnp.complex64)\n", | |
" DX = DX.at[:, Nt // 2].set(initial_values)\n", | |
"\n", | |
" j_range = jnp.arange(1 - Nt // 2, Nt // 2)\n", | |
" j = jnp.abs(j_range)\n", | |
" i = i_base + j_range\n", | |
"\n", | |
" cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0)\n", | |
" cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0)\n", | |
" cond3 = j[None, :] == 0\n", | |
"\n", | |
" jj = jj_base[:, None] + j_range[None, :]\n", | |
" val = jnp.where(cond1 | cond2, 0.0, phif[j] * data[jj])\n", | |
" DX = DX.at[:, i].set(jnp.where(cond3, DX[:, i], val))\n", | |
"\n", | |
" # Vectorized ifft\n", | |
" DX_trans = ifft(DX, axis=1)\n", | |
"\n", | |
" # Vectorized __fill_wave_2_jax\n", | |
" n_range = jnp.arange(Nt)\n", | |
" cond1 = (n_range[:, None] + f_bins[None, :]) % 2 == 1\n", | |
" cond2 = f_bins % 2 == 1\n", | |
"\n", | |
" real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))\n", | |
" imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))\n", | |
"\n", | |
" wave = jnp.where(cond1, imag_part, real_part)\n", | |
"\n", | |
" # Special cases for f_bin 0 and Nf\n", | |
" wave = wave.at[::2, 0].set(jnp.real(DX_trans[0, ::2] * jnp.sqrt(2)))\n", | |
" wave = wave.at[1::2, -1].set(jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2)))\n", | |
"\n", | |
" return wave.T\n" | |
], | |
"metadata": { | |
"id": "Cf4wt-8dX9Xb" | |
}, | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Cupy version" | |
], | |
"metadata": { | |
"id": "b_5FzcvgLHr2" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"try:\n", | |
" import cupy as cp\n", | |
" cupy_available = True\n", | |
"except:\n", | |
" cupy_available = False\n", | |
" cp = np\n", | |
"\n", | |
"\n", | |
"\n", | |
"def transform_wavelet_freq_helper_CuPy(data: cp.ndarray, Nf: int, Nt: int, phif: cp.ndarray) -> cp.ndarray:\n", | |
" if not cupy_available:\n", | |
" return None\n", | |
"\n", | |
" wave = cp.zeros((Nt, Nf), dtype=cp.complex64)\n", | |
" f_bins = cp.arange(Nf)\n", | |
"\n", | |
" # Base indices\n", | |
" i_base = Nt // 2\n", | |
" jj_base = f_bins * Nt // 2\n", | |
"\n", | |
" # Initial values\n", | |
" initial_values = cp.where(\n", | |
" (f_bins == 0) | (f_bins == Nf),\n", | |
" phif[0] * data[f_bins * Nt // 2] / 2.0,\n", | |
" phif[0] * data[f_bins * Nt // 2]\n", | |
" )\n", | |
"\n", | |
" DX = cp.zeros((Nf, Nt), dtype=cp.complex64)\n", | |
" DX[:, Nt // 2] = initial_values\n", | |
"\n", | |
" j_range = cp.arange(1 - Nt // 2, Nt // 2)\n", | |
" j = cp.abs(j_range)\n", | |
" i = i_base + j_range\n", | |
"\n", | |
" cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0)\n", | |
" cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0)\n", | |
" cond3 = j[None, :] == 0\n", | |
"\n", | |
" jj = jj_base[:, None] + j_range[None, :]\n", | |
" val = cp.where(cond1 | cond2, 0.0, phif[j] * data[jj])\n", | |
" DX[:, i] = cp.where(cond3, DX[:, i], val)\n", | |
"\n", | |
" # Vectorized ifft using CuPy\n", | |
" DX_trans = cp.fft.ifft(DX, axis=1)\n", | |
"\n", | |
" # Vectorized __fill_wave_2\n", | |
" n_range = cp.arange(Nt)\n", | |
" cond1 = (n_range[:, None] + f_bins[None, :]) % 2 == 1\n", | |
" cond2 = f_bins % 2 == 1\n", | |
"\n", | |
" real_part = cp.where(cond2, -cp.imag(DX_trans), cp.real(DX_trans))\n", | |
" imag_part = cp.where(cond2, cp.real(DX_trans), cp.imag(DX_trans))\n", | |
"\n", | |
" wave = cp.where(cond1, imag_part, real_part)\n", | |
"\n", | |
" # Special cases for f_bin 0 and Nf\n", | |
" wave[::2, 0] = cp.real(DX_trans[0, ::2] * cp.sqrt(2))\n", | |
" wave[1::2, -1] = cp.real(DX_trans[-1, ::2] * cp.sqrt(2))\n", | |
"\n", | |
" return wave.T" | |
], | |
"metadata": { | |
"id": "EWQ8ugU6LGwe" | |
}, | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"\n", | |
"KERNEL_CODE = r'''\n", | |
"extern \"C\" __global__\n", | |
"void wavelet_kernel(\n", | |
" const float* data, const float* phif, int Nt, int Nf, float* DX_real, float* DX_imag) {\n", | |
"\n", | |
" int f = blockIdx.x * blockDim.x + threadIdx.x; // Frequency index\n", | |
" int t = blockIdx.y * blockDim.y + threadIdx.y; // Time index\n", | |
"\n", | |
" if (f < Nf && t < Nt) {\n", | |
" int i_base = Nt / 2;\n", | |
" int jj_base = f * Nt / 2;\n", | |
"\n", | |
" // Initial values logic\n", | |
" if (t == Nt / 2) {\n", | |
" float init_val;\n", | |
" if (f == 0 || f == Nf - 1) {\n", | |
" init_val = phif[0] * data[jj_base] / 2.0;\n", | |
" } else {\n", | |
" init_val = phif[0] * data[jj_base];\n", | |
" }\n", | |
" DX_real[f * Nt + t] = init_val;\n", | |
" DX_imag[f * Nt + t] = 0.0f; // Initialize imaginary part to zero\n", | |
" }\n", | |
" }\n", | |
"}\n", | |
"'''\n", | |
"\n", | |
"\n", | |
"if cupy_available:\n", | |
" wavelet_kernel = cp.RawKernel(KERNEL_CODE, 'wavelet_kernel')\n", | |
"\n", | |
"import cupy as cp\n", | |
"\n", | |
"# Define the custom CUDA kernel\n", | |
"\n", | |
"\n", | |
"def transform_wavelet_freq_helper_CUDA(data: cp.ndarray, Nf: int, Nt: int, phif: cp.ndarray) -> cp.ndarray:\n", | |
" if not cupy_available:\n", | |
" return None\n", | |
" # Prepare output arrays\n", | |
" wave = cp.zeros((Nt, Nf), dtype=cp.complex64)\n", | |
"\n", | |
" # Prepare DX as separate real and imaginary parts\n", | |
" DX_real = cp.zeros((Nf, Nt), dtype=cp.float32)\n", | |
" DX_imag = cp.zeros((Nf, Nt), dtype=cp.float32)\n", | |
"\n", | |
" # Launch the CUDA kernel to populate DX\n", | |
" threads_per_block = (16, 16)\n", | |
" blocks_per_grid = ((Nf + threads_per_block[0] - 1) // threads_per_block[0],\n", | |
" (Nt + threads_per_block[1] - 1) // threads_per_block[1])\n", | |
"\n", | |
" # Unpack the blocks_per_grid tuple\n", | |
" wavelet_kernel(\n", | |
" blocks_per_grid, threads_per_block,\n", | |
" (data.data.ptr, phif.data.ptr, Nt, Nf, DX_real.data.ptr, DX_imag.data.ptr)\n", | |
" )\n", | |
"\n", | |
" # Perform inverse FFT on DX using CuPy's optimized FFT\n", | |
" DX = DX_real + 1j * DX_imag\n", | |
" DX_trans = cp.fft.ifft(DX, axis=1)\n", | |
"\n", | |
" # Fill the wave array using CuPy\n", | |
" f_bins = cp.arange(Nf)\n", | |
" n_range = cp.arange(Nt)\n", | |
"\n", | |
" # Condition logic (this part may not require CUDA kernel)\n", | |
" cond1 = (n_range[:, None] + f_bins[None, :]) % 2 == 1\n", | |
" cond2 = f_bins % 2 == 1\n", | |
"\n", | |
" real_part = cp.where(cond2, -cp.imag(DX_trans), cp.real(DX_trans))\n", | |
" imag_part = cp.where(cond2, cp.real(DX_trans), cp.imag(DX_trans))\n", | |
"\n", | |
" wave = cp.where(cond1, imag_part, real_part)\n", | |
"\n", | |
" # Special cases for f_bin 0 and Nf\n", | |
" wave[::2, 0] = cp.real(DX_trans[0, ::2] * cp.sqrt(2))\n", | |
" wave[1::2, -1] = cp.real(DX_trans[-1, ::2] * cp.sqrt(2))\n", | |
"\n", | |
" return wave.T\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "TNZzHj2zMa2f", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 859 | |
}, | |
"outputId": "f2ebf17a-1bed-40cc-eb53-1d85cb7d7a43" | |
}, | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "error", | |
"ename": "ImportError", | |
"evalue": "\n================================================================\nFailed to import CuPy.\n\nIf you installed CuPy via wheels (cupy-cudaXXX or cupy-rocm-X-X), make sure that the package matches with the version of CUDA or ROCm installed.\n\nOn Linux, you may need to set LD_LIBRARY_PATH environment variable depending on how you installed CUDA/ROCm.\nOn Windows, try setting CUDA_PATH environment variable.\n\nCheck the Installation Guide for details:\n https://docs.cupy.dev/en/latest/install.html\n\nOriginal error:\n ImportError: libcuda.so.1: cannot open shared object file: No such file or directory\n================================================================\n", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/cupy/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mcupy\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0m_core\u001b[0m \u001b[0;31m# NOQA\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/cupy/_core/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mcupy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_core\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcore\u001b[0m \u001b[0;31m# NOQA\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mcupy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_core\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mfusion\u001b[0m \u001b[0;31m# NOQA\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mImportError\u001b[0m: libcuda.so.1: cannot open shared object file: No such file or directory", | |
"\nThe above exception was the direct cause of the following exception:\n", | |
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-5-69da24a5b2b5>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mcupy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mcp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# Define the custom CUDA kernel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m wavelet_kernel = cp.RawKernel(r'''\n\u001b[1;32m 5\u001b[0m \u001b[0mextern\u001b[0m \u001b[0;34m\"C\"\u001b[0m \u001b[0m__global__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/cupy/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mcupy\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0m_core\u001b[0m \u001b[0;31m# NOQA\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m raise ImportError(f'''\n\u001b[0m\u001b[1;32m 20\u001b[0m \u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0m_environment\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_diagnose_import_error\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mImportError\u001b[0m: \n================================================================\nFailed to import CuPy.\n\nIf you installed CuPy via wheels (cupy-cudaXXX or cupy-rocm-X-X), make sure that the package matches with the version of CUDA or ROCm installed.\n\nOn Linux, you may need to set LD_LIBRARY_PATH environment variable depending on how you installed CUDA/ROCm.\nOn Windows, try setting CUDA_PATH environment variable.\n\nCheck the Installation Guide for details:\n https://docs.cupy.dev/en/latest/install.html\n\nOriginal error:\n ImportError: libcuda.so.1: cannot open shared object file: No such file or directory\n================================================================\n", | |
"", | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n" | |
], | |
"errorDetails": { | |
"actions": [ | |
{ | |
"action": "open_url", | |
"actionText": "Open Examples", | |
"url": "/notebooks/snippets/importing_libraries.ipynb" | |
} | |
] | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"# collect plotting data\n", | |
"Nf = Nt = 2 ** 6\n", | |
"yf, phif = simulate_data(Nf)\n", | |
"wave = transform_wavelet_freq_helper_numba(yf, Nf, Nt, phif)\n", | |
"jax_yf = jnp.array(yf)\n", | |
"jax_phif = jnp.array(phif)\n", | |
"wave_jax = transform_wavelet_freq_helper_JAX(jax_yf, Nf, Nt, jax_phif)\n", | |
"nplots = 2\n", | |
"\n", | |
"if cupy_available:\n", | |
" cupy_yf = cp.array(yf)\n", | |
" cupy_phif = cp.array(phif)\n", | |
" wave_cupy = transform_wavelet_freq_helper_CuPy(cupy_yf, Nf, Nt, cupy_phif).get()\n", | |
" wave_CUDA = transform_wavelet_freq_helper_CUDA(cupy_yf, Nf, Nt, cupy_phif).get()\n", | |
" nplots += 2\n", | |
"\n", | |
"# render plot\n", | |
"fig, ax = plt.subplots(1, nplots, figsize=(5, 5), sharex=True, sharey=True)\n", | |
"ax[0].imshow(np.abs(np.rot90(wave)))\n", | |
"ax[0].set_title(\"Numba\")\n", | |
"ax[1].imshow(np.abs(np.rot90(wave_jax)))\n", | |
"ax[1].set_title(\"Jax\")\n", | |
"if cupy_available:\n", | |
" ax[2].imshow(np.abs(np.rot90(wave_cupy)))\n", | |
" ax[2].set_title(\"Cupy\")\n", | |
" ax[3].imshow(np.abs(np.rot90(wave_CUDA)))\n", | |
" ax[3].set_title(\"CUDA\")\n", | |
"\n", | |
"for a in ax:\n", | |
" a.set_xticks([])\n", | |
" a.set_yticks([])\n", | |
" # a.set_ylim(120, 70)\n", | |
"plt.tight_layout()\n", | |
"plt.show()" | |
], | |
"metadata": { | |
"id": "Tx1yiDeFh7Yl" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Timing comparisons" | |
], | |
"metadata": { | |
"id": "v8MIBbNHmndU" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%timeit transform_wavelet_freq_helper_numba(yf, Nf, Nt, phif)" | |
], | |
"metadata": { | |
"id": "yT8KZlsZjFZq" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%timeit transform_wavelet_freq_helper_JAX(jax_yf, Nf, Nt, jax_phif)" | |
], | |
"metadata": { | |
"id": "X3ObJ_BrjIHC" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%timeit transform_wavelet_freq_helper_CuPy(cupy_yf, Nf, Nt, cupy_phif)" | |
], | |
"metadata": { | |
"id": "UdHonCg3OaT9" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import time\n", | |
"from tqdm.auto import tqdm\n", | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"import jax.numpy as jnp\n", | |
"\n", | |
"\n", | |
"def time_function(func, *args):\n", | |
" try:\n", | |
" t0 = time.time()\n", | |
" func(*args)\n", | |
" return time.time() - t0\n", | |
" except Exception:\n", | |
" return np.nan\n", | |
"\n", | |
"\n", | |
"def run_numba_timings(Nf=2**6, Nt=None, n_rep=10):\n", | |
" \"\"\"Run timing for Numba implementation.\"\"\"\n", | |
" Nt = Nt or Nf\n", | |
" yf, phif = simulate_data(Nf)\n", | |
" # Warmup\n", | |
" transform_wavelet_freq_helper_numba(yf, Nf, Nt, phif)\n", | |
"\n", | |
" times = np.zeros(n_rep)\n", | |
" for i in range(n_rep):\n", | |
" times[i] = time_function(transform_wavelet_freq_helper_numba, yf, Nf, Nt, phif)\n", | |
" return times\n", | |
"\n", | |
"def run_jax_timings(Nf=2**6, Nt=None, n_rep=10):\n", | |
" \"\"\"Run timing for JAX implementation.\"\"\"\n", | |
" Nt = Nt or Nf\n", | |
" yf, phif = simulate_data(Nf)\n", | |
" jax_yf, jax_phif = jnp.array(yf), jnp.array(phif)\n", | |
" # Warmup\n", | |
" transform_wavelet_freq_helper_JAX(jax_yf, Nf, Nt, jax_phif)\n", | |
"\n", | |
" times = np.zeros(n_rep)\n", | |
" for i in range(n_rep):\n", | |
" times[i] = time_function(transform_wavelet_freq_helper_JAX, jax_yf, Nf, Nt, jax_phif)\n", | |
" return times\n", | |
"\n", | |
"def run_cupy_timings(Nf=2**6, Nt=None, n_rep=10):\n", | |
" \"\"\"Run timing for CuPy implementation.\"\"\"\n", | |
" Nt = Nt or Nf\n", | |
"\n", | |
" if (Nf >= 8192 # Cupy seems to run out of mem here.\n", | |
" or not cupy_available):\n", | |
" return np.array([np.nan] * n_rep)\n", | |
"\n", | |
" yf, phif = simulate_data(Nf)\n", | |
" cupy_yf, cupy_phif = cp.array(yf), cp.array(phif)\n", | |
" # Warmup\n", | |
" transform_wavelet_freq_helper_CuPy(cupy_yf, Nf, Nt, cupy_phif)\n", | |
"\n", | |
" times = np.zeros(n_rep)\n", | |
" for i in range(n_rep):\n", | |
" times[i] = time_function(transform_wavelet_freq_helper_CuPy, cupy_yf, Nf, Nt, cupy_phif)\n", | |
" return times\n", | |
"\n", | |
"def run_experiments(min_n=5, max_n=10, n_rep=10):\n", | |
" \"\"\"\n", | |
" Run experiments for different sizes and collect runtime data.\n", | |
"\n", | |
" Parameters:\n", | |
" - min_n: Minimum power of 2 for the size of Nf.\n", | |
" - max_n: Maximum power of 2 for the size of Nf.\n", | |
" - n_rep: Number of repetitions for each size.\n", | |
"\n", | |
" Returns:\n", | |
" - pd.DataFrame: DataFrame containing runtimes for different methods and sizes.\n", | |
" \"\"\"\n", | |
" jax_label = f\"jax[{DEVICE}]_times\"\n", | |
" data = {\n", | |
" \"Ns\": [],\n", | |
" \"numba_times\": [],\n", | |
" jax_label: [],\n", | |
" \"cupy_times\": []\n", | |
" }\n", | |
"\n", | |
" Nfs = [2**i for i in range(min_n, max_n)]\n", | |
"\n", | |
" pbar = tqdm(Nfs, desc=\"Running experiments\")\n", | |
" for Nf in pbar:\n", | |
" pbar.set_postfix({\"Nf\": Nf})\n", | |
" numba_times = run_numba_timings(Nf=Nf, n_rep=n_rep)\n", | |
" jax_times = run_jax_timings(Nf=Nf, n_rep=n_rep)\n", | |
" cupy_times = run_cupy_timings(Nf=Nf, n_rep=n_rep)\n", | |
"\n", | |
" data[\"Ns\"].extend([Nf] * n_rep)\n", | |
" data[\"numba_times\"].extend(numba_times)\n", | |
" data[jax_label].extend(jax_times)\n", | |
" data[\"cupy_times\"].extend(cupy_times)\n", | |
"\n", | |
" return pd.DataFrame(data)\n", | |
"\n", | |
"\n", | |
"\n", | |
"# Running the experiments and collecting data\n", | |
"data = run_experiments(max_n=14, n_rep=5)\n", | |
"data.to_csv(\"runtimes.csv\", index=False)" | |
], | |
"metadata": { | |
"id": "rXA0cHe5gahJ" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import numpy as np\n", | |
"import os\n", | |
"import matplotlib.pyplot as plt\n", | |
"import pandas as pd\n", | |
"\n", | |
"def plot_runtime_vs_size(data, column_label='numba', color='blue', ax=None):\n", | |
" \"\"\"\n", | |
" Plots the runtime vs size graph with error bands for a given column in the dataset.\n", | |
"\n", | |
" Parameters:\n", | |
" - data: DataFrame containing the data to be plotted.\n", | |
" - column_label: Label for the column prefix (e.g., 'numba', 'jax', 'cupy') to plot.\n", | |
" - color: Color for the plot lines and fill area (default: 'blue').\n", | |
" - ax: Optional matplotlib axis to plot on. If None, a new figure and axis are created.\n", | |
"\n", | |
" Returns:\n", | |
" - ax: The axis with the plot.\n", | |
" \"\"\"\n", | |
"\n", | |
" df = data.copy()\n", | |
" grouped = df.groupby('Ns').agg(\n", | |
" mean=(f'{column_label}_times', 'mean'),\n", | |
" sem=(f'{column_label}_times', 'sem')\n", | |
" ).reset_index()\n", | |
" grouped['exponent'] = np.log2(grouped['Ns']).astype(int)\n", | |
"\n", | |
" # Create an axis if not provided\n", | |
" if ax is None:\n", | |
" fig, ax = plt.subplots(figsize=(4, 3))\n", | |
"\n", | |
"\n", | |
" # Plotting\n", | |
" ax.fill_between(grouped['exponent'],\n", | |
" grouped['mean'] - grouped['sem'],\n", | |
" grouped['mean'] + grouped['sem'],\n", | |
" color=color, alpha=0.3, label=f'{column_label.capitalize()}', lw=0)\n", | |
" ax.plot(grouped['exponent'], grouped['mean'], color=color, marker='')\n", | |
"\n", | |
" # Set labels, legend\n", | |
" ax.set_xlabel('$N_fxN_t$', fontsize=12)\n", | |
" ax.set_ylabel('Runtime [s]', fontsize=12)\n", | |
" ax.legend(fontsize=10, frameon=False)\n", | |
"\n", | |
" ax.grid(True, which=\"both\", linestyle=\"-\", alpha=0.2)\n", | |
" fmt = '$2^{e}x2^{e}$'\n", | |
" ax.set_xticks(grouped['exponent'])\n", | |
" ax.set_xticklabels([fmt.replace(\"e\", str(exp)) for exp in grouped['exponent']])\n", | |
" ax.tick_params(axis='x', rotation=45)\n", | |
" ax.set_yscale('log')\n", | |
" plt.tight_layout()\n", | |
" return ax\n", | |
"\n", | |
"if 'runtimes_gpu.csv' in os.listdir():\n", | |
" df_gpu = pd.read_csv(\"runtimes_gpu.csv\")\n", | |
" df_tpu = pd.read_csv(\"runtimes_tpu.csv\")\n", | |
" df = df_gpu.copy()\n", | |
" df['jax[TPU v2]_times'] = df_tpu['jax[TPU v2]_times']\n", | |
"else:\n", | |
" df = pd.read_csv(\"runtimes.csv\")\n", | |
"fig, ax = plt.subplots(figsize=(4, 3))\n", | |
"labels = [i.split(\"_times\")[0] for i in df.columns.values if 'times' in i]\n", | |
"for i, l in enumerate(labels):\n", | |
" plot_runtime_vs_size(df, column_label=l, color=f\"C{i}\", ax=ax)" | |
], | |
"metadata": { | |
"id": "bNoJIMuHFftq" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Note: we've used lower precision data for JAX + Cupy (this can be changed)." | |
], | |
"metadata": { | |
"id": "m5UPtRxf32cz" | |
} | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment