Skip to content

Instantly share code, notes, and snippets.

@avivajpeyi
Last active April 28, 2025 19:51
Show Gist options
  • Save avivajpeyi/b0a99b3f54b6841fb37e60af312b58a4 to your computer and use it in GitHub Desktop.
Save avivajpeyi/b0a99b3f54b6841fb37e60af312b58a4 to your computer and use it in GitHub Desktop.
wdm_transform.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyOk7xD8r533f+eImNBxoHDY",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/avivajpeyi/b0a99b3f54b6841fb37e60af312b58a4/wdm_transform.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Comparing JAX and Numba WDM Transforms\n",
"\n",
"\n",
"## Common functions"
],
"metadata": {
"id": "H6YcKEr5mHIV"
}
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"id": "0gJu4ONiQ6sB",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4ad76cba-8b79-4422-ddc6-deb4939d91e8"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Running JAX on cpu\n"
]
}
],
"source": [
"\"\"\" COMMON FUNCS \"\"\"\n",
"import numpy as np\n",
"from numpy import fft\n",
"from scipy.special import betainc\n",
"from scipy.signal import chirp\n",
"from typing import List, Tuple\n",
"import jax\n",
"# Explicitly keep JAX's 64-bit mode off, so JAX arrays use 32-bit floats / complex64.\n",
"jax.config.update('jax_enable_x64', False)\n",
"\n",
"# Report the kind of the first device JAX can see (the saved run printed 'cpu').\n",
"DEVICE = jax.devices()[0].device_kind\n",
"print(f\"Running JAX on {DEVICE}\")\n",
"\n",
"# Shorthand used by phitilde_vec.\n",
"PI = np.pi\n",
"# [f0, f1]: start and end frequency handed to scipy.signal.chirp in simulate_data.\n",
"FREQ_RANGE = [20, 100]\n",
"\n",
"def phitilde_vec_norm(Nf: int, Nt: int, dt: float, d: float) -> np.ndarray:\n",
"    \"\"\"Sample the frequency-domain window on Nt // 2 + 1 bins and rescale it.\n",
"\n",
"    Args:\n",
"        Nf: number of frequency bins.\n",
"        Nt: number of time bins.\n",
"        dt: sampling interval of the time series.\n",
"        d: steepness parameter forwarded to ``phitilde_vec``.\n",
"\n",
"    Returns:\n",
"        Array of length ``Nt // 2 + 1``: ``phitilde_vec`` on the angular-frequency\n",
"        grid, divided by ``pi ** (-1 / 2)``.\n",
"    \"\"\"\n",
"    ND = Nf * Nt\n",
"    # phitilde_vec derives its band edges from dOmega = 2 * pi / (2 * Nf * dt), so the\n",
"    # grid needs the same 1 / dt factor. With the bare 2 * pi / ND spacing and a small dt\n",
"    # (simulate_data uses 1 / 512) every sample sat below the lower edge A and the\n",
"    # window came out constant.\n",
"    omegas = 2 * np.pi / (ND * dt) * np.arange(0, Nt // 2 + 1)\n",
"    u_phit = phitilde_vec(omegas, Nf, dt, d)\n",
"    normalising_factor = np.pi ** (-1 / 2)  # Ollie's normalising factor\n",
"    return u_phit / normalising_factor\n",
"\n",
"def phitilde_vec(omega: np.ndarray, Nf: int, dt: float, d: float = 4.0) -> np.ndarray:\n",
"    \"\"\"Compute phi_tilde(omega_i) array.\n",
"\n",
"    The window is flat (1 / sqrt(dOmega)) for |omega| < A, follows a cosine taper\n",
"    driven by ``_nu_d`` for A <= |omega| < A + B, and is zero from A + B upwards.\n",
"\n",
"    Args:\n",
"        omega: angular frequencies at which to evaluate the window.\n",
"        Nf: number of frequency bins; sets dOmega together with ``dt``.\n",
"        dt: sampling interval.\n",
"        d: steepness parameter passed to ``_nu_d``.\n",
"\n",
"    Returns:\n",
"        Float array with the same shape as ``omega``.\n",
"\n",
"    Raises:\n",
"        ValueError: if the taper width B is not positive.\n",
"    \"\"\"\n",
"    dF = 1.0 / (2 * Nf * dt)\n",
"    dOmega = 2 * PI * dF\n",
"    inverse_sqrt_dOmega = 1.0 / np.sqrt(dOmega)\n",
"    # NOTE(review): published Meyer-type WDM windows use 2 * A + B = dOmega (B = dOmega / 2);\n",
"    # here B = 3 * dOmega / 4 -- confirm this is intentional.\n",
"    A, B = dOmega / 4, 3 * dOmega / 4\n",
"    if B <= 0:\n",
"        raise ValueError(\"B must be greater than 0\")\n",
"    # Force a float result even if omega is an integer array.\n",
"    abs_omega = np.abs(np.asarray(omega, dtype=float))\n",
"    # Start from zero so everything at or beyond A + B stays outside the window's support.\n",
"    phi = np.zeros_like(abs_omega)\n",
"    phi[abs_omega < A] = inverse_sqrt_dOmega\n",
"    taper = (A <= abs_omega) & (abs_omega < A + B)\n",
"    phi[taper] = inverse_sqrt_dOmega * np.cos((PI / 2.0) * _nu_d(abs_omega[taper], A, B, d))\n",
"    return phi\n",
"\n",
"def _nu_d(omega: np.ndarray, A: float, B: float, d: float = 4.0) -> np.ndarray:\n",
"    \"\"\"Regularised incomplete beta function I_x(d, d) at x = (|omega| - A) / B.\"\"\"\n",
"    magnitude = np.abs(omega)\n",
"    scaled = (magnitude - A) / B\n",
"    return betainc(d, d, scaled)\n",
"\n",
"\n",
"def simulate_data(Nf, plot=False):\n",
"    \"\"\"Build a quadratic-chirp test signal and the matching frequency-domain window.\n",
"\n",
"    Args:\n",
"        Nf: number of frequency bins; must be a positive power of two. Nt is set equal to Nf.\n",
"        plot: if True, also draw the window with matplotlib. Off by default so that\n",
"            repeated calls (e.g. from the timing helpers) do not pile lines onto a figure\n",
"            and the function does not depend on a ``plt`` imported by a later cell.\n",
"\n",
"    Returns:\n",
"        (yf, phif): the first ND // 2 + 1 FFT bins of the chirp (ND = Nf * Nt) and the\n",
"        window returned by ``phitilde_vec_norm``.\n",
"    \"\"\"\n",
"    # Nf == 0 would satisfy the bit trick alone, so require positivity explicitly.\n",
"    assert Nf > 0 and Nf & (Nf - 1) == 0, \"Nf must be a power of 2\"\n",
"    fs = 512\n",
"    dt = 1 / fs\n",
"    Nt = Nf\n",
"    nx = 4.0\n",
"    ND = Nt * Nf\n",
"    t = np.arange(0, ND) * dt\n",
"    y = chirp(t, f0=FREQ_RANGE[0], f1=FREQ_RANGE[1], t1=t[-1], method=\"quadratic\")\n",
"    phif = phitilde_vec_norm(Nf, Nt, dt=dt, d=nx)\n",
"    yf = fft.fft(y)[:ND//2+1]\n",
"\n",
"    if plot:\n",
"        import matplotlib.pyplot as plt\n",
"\n",
"        plt.plot(phif)\n",
"\n",
"    return yf, phif\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"source": [
"## NUMBA Transform"
],
"metadata": {
"id": "4JhI5f_A3p2V"
}
},
{
"cell_type": "code",
"source": [
"\"\"\"NUMBA VERSION\"\"\"\n",
"import numpy as np\n",
"from numba import njit\n",
"from numpy import fft\n",
"\n",
"\n",
"def transform_wavelet_freq_helper_numba(\n",
"    data: np.ndarray, Nf: int, Nt: int, phif: np.ndarray\n",
") -> np.ndarray:\n",
"    \"\"\"Frequency-domain wavelet transform: loop over the layers f_bin = 0 .. Nf.\n",
"\n",
"    Each layer is windowed into a scratch buffer, inverse-FFT'd and packed into the\n",
"    real (Nt, Nf) coefficient array that is returned.\n",
"    \"\"\"\n",
"    coeffs = np.zeros((Nt, Nf))\n",
"    scratch = np.zeros(Nt, dtype=np.complex128)\n",
"    spectrum = data.copy()\n",
"\n",
"    f_bin = 0\n",
"    while f_bin <= Nf:\n",
"        __fill_wave_1_numba(f_bin, Nt, Nf, scratch, spectrum, phif)\n",
"        # The inverse FFT runs outside the jitted helpers because numba doesn't support np.fft.\n",
"        scratch_time = fft.ifft(scratch, Nt)\n",
"        __fill_wave_2_numba(f_bin, scratch_time, coeffs, Nt, Nf)\n",
"        f_bin += 1\n",
"\n",
"    return coeffs\n",
"\n",
"\n",
"@njit()\n",
"def __fill_wave_1_numba(\n",
"    f_bin: int,\n",
"    Nt: int,\n",
"    Nf: int,\n",
"    DX: np.ndarray,\n",
"    data: np.ndarray,\n",
"    phif: np.ndarray,\n",
") -> None:\n",
"    \"\"\"Fill DX[1:] in place with the windowed spectrum around layer f_bin.\n",
"\n",
"    DX[Nt // 2] receives the centre bin data[f_bin * Nt // 2]; the neighbours at\n",
"    offsets -(Nt // 2 - 1) .. Nt // 2 - 1 are weighted by phif[|offset|].\n",
"    \"\"\"\n",
"    half = Nt // 2\n",
"    centre = f_bin * Nt // 2\n",
"    at_lowest = f_bin == 0\n",
"    at_highest = f_bin == Nf\n",
"\n",
"    centre_value = phif[0] * data[centre]\n",
"    if at_lowest or at_highest:\n",
"        # NOTE this term appears to be needed to recover correct constant (at least for m=0), but was previously missing\n",
"        centre_value = centre_value / 2.0\n",
"    DX[half] = centre_value\n",
"\n",
"    for offset in range(1 - half, half):\n",
"        slot = half + offset\n",
"        if (at_highest and offset > 0) or (at_lowest and offset < 0):\n",
"            # These neighbours would lie outside data (below bin 0 or above the last bin).\n",
"            DX[slot] = 0.0\n",
"        elif offset != 0:\n",
"            DX[slot] = phif[abs(offset)] * data[centre + offset]\n",
"\n",
"\n",
"@njit()\n",
"def __fill_wave_2_numba(\n",
"    f_bin: int, DX_trans: np.ndarray, wave: np.ndarray, Nt: int, Nf: int\n",
") -> None:\n",
"    \"\"\"Pack the inverse-transformed layer f_bin into the real array wave, in place.\"\"\"\n",
"    if f_bin == 0 or f_bin == Nf:\n",
"        # half of lowest and highest frequency bin pixels are redundant, so store them in\n",
"        # even (f_bin == 0) and odd (f_bin == Nf) rows of column 0 respectively\n",
"        row_shift = 0 if f_bin == 0 else 1\n",
"        for n in range(0, Nt, 2):\n",
"            wave[n + row_shift, 0] = np.real(DX_trans[n] * np.sqrt(2))\n",
"        return\n",
"\n",
"    layer_is_odd = f_bin % 2\n",
"    for n in range(0, Nt):\n",
"        if (n + f_bin) % 2 == 0:\n",
"            wave[n, f_bin] = np.real(DX_trans[n])\n",
"        elif layer_is_odd:\n",
"            wave[n, f_bin] = -np.imag(DX_trans[n])\n",
"        else:\n",
"            wave[n, f_bin] = np.imag(DX_trans[n])\n",
"\n",
"\n",
"\n"
],
"metadata": {
"id": "RkJzTw0t3pPw"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## JAX WDM transform"
],
"metadata": {
"id": "u9mBzg9Vmhgm"
}
},
{
"cell_type": "code",
"source": [
"\"\"\"JAX VERSION\"\"\"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from functools import partial\n",
"from jax import jit\n",
"from jax.numpy.fft import ifft\n",
"\n",
"@partial(jit, static_argnames=('Nf', 'Nt'))\n",
"def transform_wavelet_freq_helper_JAX(\n",
"    data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray\n",
") -> jnp.ndarray:\n",
"    \"\"\"Vectorised JAX port of ``transform_wavelet_freq_helper_numba``.\n",
"\n",
"    All Nf + 1 frequency layers (f_bin = 0 .. Nf inclusive) are windowed at once,\n",
"    inverse-FFT'd along the time axis and packed into a real array with the same\n",
"    rules as the Numba helpers.\n",
"\n",
"    Args:\n",
"        data: one-sided frequency series; bins 0 .. Nf * Nt // 2 are read.\n",
"        Nf: number of frequency bins (static under jit).\n",
"        Nt: number of time bins (static under jit); assumed even.\n",
"        phif: window samples; indices 0 .. Nt // 2 - 1 are read.\n",
"\n",
"    Returns:\n",
"        Real array of shape (Nf, Nt), i.e. the transpose of the (Nt, Nf) array\n",
"        produced by the Numba helper.\n",
"    \"\"\"\n",
"    half = Nt // 2\n",
"    f_bins = jnp.arange(Nf + 1)  # Nf + 1 layers, so the f_bin == Nf layer exists\n",
"    offsets = jnp.arange(1 - half, half)  # neighbour offsets; DX column = half + offset\n",
"    jj_base = f_bins * Nt // 2\n",
"\n",
"    # Indices that fall outside data are masked to zero below; clip so the gather is in range.\n",
"    jj = jnp.clip(jj_base[:, None] + offsets[None, :], 0, data.shape[0] - 1)\n",
"    vals = phif[jnp.abs(offsets)][None, :] * data[jj]\n",
"\n",
"    lowest = f_bins[:, None] == 0\n",
"    highest = f_bins[:, None] == Nf\n",
"    outside = (lowest & (offsets[None, :] < 0)) | (highest & (offsets[None, :] > 0))\n",
"    vals = jnp.where(outside, 0.0, vals)\n",
"    # The centre bin of the first and last layer is halved.\n",
"    centre_scale = jnp.where(lowest | highest, 0.5, 1.0)\n",
"    vals = jnp.where(offsets[None, :] == 0, vals * centre_scale, vals)\n",
"\n",
"    # Column 0 of DX is never written by the reference loop and stays zero.\n",
"    DX = jnp.concatenate([jnp.zeros((Nf + 1, 1), dtype=vals.dtype), vals], axis=1)\n",
"\n",
"    # Vectorized ifft over the time axis of every layer\n",
"    DX_trans = ifft(DX, axis=1)\n",
"\n",
"    # Layers 0 .. Nf - 1 laid out as (Nt, Nf) so rows are time and columns are frequency.\n",
"    layers = DX_trans[:Nf].T\n",
"    n_idx = jnp.arange(Nt)[:, None]\n",
"    f_idx = jnp.arange(Nf)[None, :]\n",
"    # (n + f) even -> real part; (n + f) odd -> imaginary part, negated for odd f.\n",
"    imag_sign = jnp.where(f_idx % 2 == 1, -1.0, 1.0)\n",
"    wave = jnp.where((n_idx + f_idx) % 2 == 1, imag_sign * jnp.imag(layers), jnp.real(layers))\n",
"\n",
"    # Column 0 shares the f_bin == 0 layer (even rows) and the f_bin == Nf layer (odd rows).\n",
"    sqrt2 = jnp.sqrt(2.0)\n",
"    wave = wave.at[0::2, 0].set(jnp.real(DX_trans[0, 0::2]) * sqrt2)\n",
"    wave = wave.at[1::2, 0].set(jnp.real(DX_trans[Nf, 0::2]) * sqrt2)\n",
"\n",
"    return wave.T\n"
],
"metadata": {
"id": "Cf4wt-8dX9Xb"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## CuPy version"
],
"metadata": {
"id": "b_5FzcvgLHr2"
}
},
{
"cell_type": "code",
"source": [
"try:\n",
"    import cupy as cp\n",
"    cupy_available = True\n",
"except ImportError:\n",
"    # CuPy is optional: fall back to NumPy under the same alias so the cp.ndarray\n",
"    # annotations still resolve. Only a failed import is caught; other errors propagate.\n",
"    cupy_available = False\n",
"    cp = np\n",
"\n",
"\n",
"\n",
"def transform_wavelet_freq_helper_CuPy(data: cp.ndarray, Nf: int, Nt: int, phif: cp.ndarray) -> cp.ndarray:\n",
"    \"\"\"Vectorised CuPy port of ``transform_wavelet_freq_helper_numba``.\n",
"\n",
"    All Nf + 1 frequency layers (f_bin = 0 .. Nf inclusive) are windowed at once,\n",
"    inverse-FFT'd along the time axis and packed with the same rules as the Numba helpers.\n",
"\n",
"    Args:\n",
"        data: one-sided frequency series; bins 0 .. Nf * Nt // 2 are read.\n",
"        Nf: number of frequency bins.\n",
"        Nt: number of time bins; assumed even.\n",
"        phif: window samples; indices 0 .. Nt // 2 - 1 are read.\n",
"\n",
"    Returns:\n",
"        Real array of shape (Nf, Nt) (transpose of the Numba helper's layout), or\n",
"        None when CuPy could not be imported.\n",
"    \"\"\"\n",
"    if not cupy_available:\n",
"        return None\n",
"\n",
"    half = Nt // 2\n",
"    f_bins = cp.arange(Nf + 1)  # Nf + 1 layers, so the f_bin == Nf layer exists\n",
"    offsets = cp.arange(1 - half, half)  # neighbour offsets; DX column = half + offset\n",
"    jj_base = f_bins * Nt // 2\n",
"\n",
"    # Indices that fall outside data are masked to zero below; clip so the gather is in range.\n",
"    jj = cp.clip(jj_base[:, None] + offsets[None, :], 0, data.shape[0] - 1)\n",
"    vals = phif[cp.abs(offsets)][None, :] * data[jj]\n",
"\n",
"    lowest = f_bins[:, None] == 0\n",
"    highest = f_bins[:, None] == Nf\n",
"    outside = (lowest & (offsets[None, :] < 0)) | (highest & (offsets[None, :] > 0))\n",
"    vals = cp.where(outside, 0.0, vals)\n",
"    # The centre bin of the first and last layer is halved.\n",
"    centre_scale = 1.0 - 0.5 * (lowest | highest)\n",
"    vals = cp.where(offsets[None, :] == 0, vals * centre_scale, vals)\n",
"\n",
"    # Column 0 of DX is never written by the reference loop and stays zero.\n",
"    DX = cp.zeros((Nf + 1, Nt), dtype=vals.dtype)\n",
"    DX[:, 1:] = vals\n",
"\n",
"    # Vectorized ifft using CuPy\n",
"    DX_trans = cp.fft.ifft(DX, axis=1)\n",
"\n",
"    # Layers 0 .. Nf - 1 laid out as (Nt, Nf) so rows are time and columns are frequency.\n",
"    layers = DX_trans[:Nf].T\n",
"    n_idx = cp.arange(Nt)[:, None]\n",
"    f_idx = cp.arange(Nf)[None, :]\n",
"    # (n + f) even -> real part; (n + f) odd -> imaginary part, negated for odd f.\n",
"    imag_sign = 1.0 - 2.0 * (f_idx % 2)\n",
"    wave = cp.where((n_idx + f_idx) % 2 == 1, imag_sign * cp.imag(layers), cp.real(layers))\n",
"\n",
"    # Column 0 shares the f_bin == 0 layer (even rows) and the f_bin == Nf layer (odd rows).\n",
"    sqrt2 = 2.0 ** 0.5\n",
"    wave[0::2, 0] = cp.real(DX_trans[0, 0::2]) * sqrt2\n",
"    wave[1::2, 0] = cp.real(DX_trans[Nf, 0::2]) * sqrt2\n",
"\n",
"    return wave.T"
],
"metadata": {
"id": "EWQ8ugU6LGwe"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"# CUDA C source for `wavelet_kernel`, one thread per (frequency f, time t) pair.\n",
"# As written, only the centre element (t == Nt / 2) of each frequency row of\n",
"# DX_real / DX_imag is stored; no other element is written and `i_base` is unused.\n",
"# NOTE(review): `data` is declared `const float*`, while the commented-out driver code in\n",
"# this notebook builds the input from the complex spectrum `yf` -- confirm the layout.\n",
"# NOTE(review): the halving test here is `f == Nf - 1`; the Python helpers test `f_bins == Nf`.\n",
"KERNEL_CODE = r'''\n",
"extern \"C\" __global__\n",
"void wavelet_kernel(\n",
" const float* data, const float* phif, int Nt, int Nf, float* DX_real, float* DX_imag) {\n",
"\n",
" int f = blockIdx.x * blockDim.x + threadIdx.x; // Frequency index\n",
" int t = blockIdx.y * blockDim.y + threadIdx.y; // Time index\n",
"\n",
" if (f < Nf && t < Nt) {\n",
" int i_base = Nt / 2;\n",
" int jj_base = f * Nt / 2;\n",
"\n",
" // Initial values logic\n",
" if (t == Nt / 2) {\n",
" float init_val;\n",
" if (f == 0 || f == Nf - 1) {\n",
" init_val = phif[0] * data[jj_base] / 2.0;\n",
" } else {\n",
" init_val = phif[0] * data[jj_base];\n",
" }\n",
" DX_real[f * Nt + t] = init_val;\n",
" DX_imag[f * Nt + t] = 0.0f; // Initialize imaginary part to zero\n",
" }\n",
" }\n",
"}\n",
"'''\n",
"\n",
"\n",
"# Define the custom CUDA kernel (only when CuPy imported successfully)\n",
"if cupy_available:\n",
"    wavelet_kernel = cp.RawKernel(KERNEL_CODE, 'wavelet_kernel')\n",
"else:\n",
"    # `cp` is the NumPy fallback in this case; importing cupy again unconditionally\n",
"    # would raise and defeat the optional-import guard.\n",
"    wavelet_kernel = None\n",
"\n",
"\n",
"\n",
"def transform_wavelet_freq_helper_CUDA(data: cp.ndarray, Nf: int, Nt: int, phif: cp.ndarray) -> cp.ndarray:\n",
" \"\"\"Raw-kernel variant: fill DX with `wavelet_kernel`, inverse FFT it, pack the result.\n",
"\n",
" Returns None when CuPy is unavailable, otherwise the transpose of the packed array.\n",
"\n",
" NOTE(review): cond2 has shape (Nf,) and is broadcast against the (Nf, Nt) DX_trans,\n",
" while cond1 has shape (Nt, Nf), so the cp.where calls only line up when Nf == Nt.\n",
" NOTE(review): the last special case reads DX_trans[-1] (row Nf - 1) and writes column -1;\n",
" the Numba helper uses the f_bin == Nf layer and the odd rows of column 0 -- confirm.\n",
" \"\"\"\n",
" if not cupy_available:\n",
" return None\n",
" # Prepare output arrays (this allocation is replaced by the cp.where result below)\n",
" wave = cp.zeros((Nt, Nf), dtype=cp.complex64)\n",
"\n",
" # Prepare DX as separate real and imaginary parts\n",
" DX_real = cp.zeros((Nf, Nt), dtype=cp.float32)\n",
" DX_imag = cp.zeros((Nf, Nt), dtype=cp.float32)\n",
"\n",
" # Launch the CUDA kernel to populate DX: 16 x 16 threads per block, with enough blocks\n",
" # (ceiling division) to cover Nf along x and Nt along y\n",
" threads_per_block = (16, 16)\n",
" blocks_per_grid = ((Nf + threads_per_block[0] - 1) // threads_per_block[0],\n",
" (Nt + threads_per_block[1] - 1) // threads_per_block[1])\n",
"\n",
" # Kernel arguments are passed as raw device pointers plus the two sizes.\n",
" # NOTE(review): Nt / Nf are plain Python ints while the kernel declares `int` --\n",
" # confirm CuPy's scalar-argument conversion matches (e.g. wrap them in np.int32).\n",
" wavelet_kernel(\n",
" blocks_per_grid, threads_per_block,\n",
" (data.data.ptr, phif.data.ptr, Nt, Nf, DX_real.data.ptr, DX_imag.data.ptr)\n",
" )\n",
"\n",
" # Perform inverse FFT on DX using CuPy's optimized FFT\n",
" DX = DX_real + 1j * DX_imag\n",
" DX_trans = cp.fft.ifft(DX, axis=1)\n",
"\n",
" # Fill the wave array using CuPy\n",
" f_bins = cp.arange(Nf)\n",
" n_range = cp.arange(Nt)\n",
"\n",
" # Condition logic (this part may not require CUDA kernel)\n",
" cond1 = (n_range[:, None] + f_bins[None, :]) % 2 == 1\n",
" cond2 = f_bins % 2 == 1\n",
"\n",
" real_part = cp.where(cond2, -cp.imag(DX_trans), cp.real(DX_trans))\n",
" imag_part = cp.where(cond2, cp.real(DX_trans), cp.imag(DX_trans))\n",
"\n",
" wave = cp.where(cond1, imag_part, real_part)\n",
"\n",
" # Special cases for f_bin 0 and Nf\n",
" wave[::2, 0] = cp.real(DX_trans[0, ::2] * cp.sqrt(2))\n",
" wave[1::2, -1] = cp.real(DX_trans[-1, ::2] * cp.sqrt(2))\n",
"\n",
" return wave.T\n",
"\n"
],
"metadata": {
"id": "TNZzHj2zMa2f"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# collect plotting data: a 64 x 64 time-frequency grid (Nf = Nt = 2 ** 6)\n",
"Nf = Nt = 2 ** 6\n",
"yf, phif = simulate_data(Nf)\n",
"# Reference result from the Numba implementation.\n",
"wave = transform_wavelet_freq_helper_numba(yf, Nf, Nt, phif)\n",
"# Same inputs converted to JAX arrays for the JAX implementation.\n",
"jax_yf = jnp.array(yf)\n",
"jax_phif = jnp.array(phif)\n",
"wave_jax = transform_wavelet_freq_helper_JAX(jax_yf, Nf, Nt, jax_phif)\n",
"# Number of subplot panels drawn at the end of this cell.\n",
"nplots = 3\n",
"\n",
"# The saved run printed (33,) and (2049,), i.e. Nt // 2 + 1 and Nf * Nt // 2 + 1.\n",
"print(f\"Phif shape {phif.shape}\")\n",
"print(f\"data shape {jax_yf.shape}\")\n",
"\n",
"\n",
"\n",
"def xn_i(m: int, data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray) -> jnp.ndarray:\n",
"    \"\"\"Inverse FFT of the windowed spectrum segment starting at bin m * Nt // 2.\n",
"\n",
"    Args:\n",
"        m: segment index; data is rolled left by m * Nt // 2 bins (may be a traced value).\n",
"        data: frequency series.\n",
"        Nf: unused here, kept for a uniform call signature.\n",
"        Nt: number of time bins.\n",
"        phif: window samples multiplied onto the segment.\n",
"\n",
"    Returns:\n",
"        Complex array with the same length as ``phif``.\n",
"    \"\"\"\n",
"    # Keep exactly as many bins as there are window samples. phif from phitilde_vec_norm\n",
"    # has Nt // 2 + 1 entries, so slicing to Nt // 2 made the product fail to broadcast.\n",
"    n_keep = phif.shape[0]\n",
"    segment = jnp.roll(data, -(m * Nt // 2))[:n_keep]\n",
"    return jnp.fft.ifft(segment * phif)\n",
"\n",
"\n",
"def transform_conv(data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray) -> jnp.ndarray:\n",
"    \"\"\"Evaluate ``xn_i`` for m = 0 .. Nt - 1 and return the results as columns.\n",
"\n",
"    A lax.fori_loop carry must keep a fixed shape, so it cannot grow an array one column\n",
"    per iteration; the segments are mapped with jax.vmap instead.\n",
"\n",
"    Returns:\n",
"        Complex array of shape (len(phif), Nt); column m is ``xn_i(m, ...)``.\n",
"    \"\"\"\n",
"    # NOTE(review): an earlier draft computed jnp.fft.fftshift(data) here (\"to match up with\n",
"    # Neil Eq 17\") but never used it; the unshifted data is still what is transformed -- confirm.\n",
"    segment_columns = jax.vmap(lambda m: xn_i(m, data, Nf, Nt, phif), out_axes=1)\n",
"    return segment_columns(jnp.arange(Nt))\n",
"\n",
"\n",
"# @partial(jit, static_argnames=('Nf', 'Nt'))\n",
"# def transform_conv(\n",
"# data: jnp.ndarray, Nf: int, Nt: int, phif: jnp.ndarray\n",
"# ) -> jnp.ndarray:\n",
"\n",
"\n",
"\n",
"\n",
"# wave = jnp.zeros((Nt, Nf))\n",
"# return\n",
"\n",
"\n",
"\n",
"\n",
" # i_base = Nt // 2\n",
" # jj_base = f_bins * Nt // 2\n",
"\n",
" # initial_values = jnp.where(\n",
" # (f_bins == 0) | (f_bins == Nf),\n",
" # phif[0] * data[f_bins * Nt // 2] / 2.0,\n",
" # phif[0] * data[f_bins * Nt // 2]\n",
" # )\n",
"\n",
" # DX = jnp.zeros((Nf, Nt), dtype=jnp.complex64)\n",
" # DX = DX.at[:, Nt // 2].set(initial_values)\n",
"\n",
" # j_range = jnp.arange(1 - Nt // 2, Nt // 2)\n",
" # j = jnp.abs(j_range)\n",
" # i = i_base + j_range\n",
"\n",
" # cond1 = (f_bins[:, None] == Nf) & (j_range[None, :] > 0)\n",
" # cond2 = (f_bins[:, None] == 0) & (j_range[None, :] < 0)\n",
" # cond3 = j[None, :] == 0\n",
"\n",
" # jj = jj_base[:, None] + j_range[None, :]\n",
" # val = jnp.where(cond1 | cond2, 0.0, phif[j] * data[jj])\n",
" # DX = DX.at[:, i].set(jnp.where(cond3, DX[:, i], val))\n",
"\n",
" # # Vectorized ifft\n",
" # DX_trans = ifft(DX, axis=1)\n",
"\n",
" # # Vectorized __fill_wave_2_jax\n",
" # n_range = jnp.arange(Nt)\n",
" # cond1 = (n_range[:, None] + f_bins[None, :]) % 2 == 1\n",
" # cond2 = f_bins % 2 == 1\n",
"\n",
" # real_part = jnp.where(cond2, -jnp.imag(DX_trans), jnp.real(DX_trans))\n",
" # imag_part = jnp.where(cond2, jnp.real(DX_trans), jnp.imag(DX_trans))\n",
"\n",
" # wave = jnp.where(cond1, imag_part, real_part)\n",
"\n",
" # # Special cases for f_bin 0 and Nf\n",
" # wave = wave.at[::2, 0].set(jnp.real(DX_trans[0, ::2] * jnp.sqrt(2)))\n",
" # wave = wave.at[1::2, -1].set(jnp.real(DX_trans[-1, ::2] * jnp.sqrt(2)))\n",
"\n",
" # return wave.T\n",
"\n",
"\n",
"\n",
"wave_conv = transform_conv(jax_yf, Nf, Nt, jax_phif)\n",
"\n",
"# if cupy_available:\n",
"# cupy_yf = cp.array(yf)\n",
"# cupy_phif = cp.array(phif)\n",
"# wave_cupy = transform_wavelet_freq_helper_CuPy(cupy_yf, Nf, Nt, cupy_phif).get()\n",
"# wave_CUDA = transform_wavelet_freq_helper_CUDA(cupy_yf, Nf, Nt, cupy_phif).get()\n",
"# nplots += 2\n",
"\n",
"# render plot: one panel per implementation, each showing |coefficients| rotated by 90 degrees\n",
"panels = [(\"Numba\", wave), (\"Jax\", wave_jax), (\"JaxCONV\", wave_conv)]\n",
"fig, ax = plt.subplots(1, nplots, figsize=(5, 5), sharex=True, sharey=True)\n",
"for panel_index, (panel_title, panel_data) in enumerate(panels):\n",
"    ax[panel_index].imshow(np.abs(np.rot90(panel_data)))\n",
"    ax[panel_index].set_title(panel_title)\n",
"# if cupy_available:\n",
"# ax[2].imshow(np.abs(np.rot90(wave_cupy)))\n",
"# ax[2].set_title(\"Cupy\")\n",
"# ax[3].imshow(np.abs(np.rot90(wave_CUDA)))\n",
"# ax[3].set_title(\"CUDA\")\n",
"\n",
"# Hide the tick marks on every panel.\n",
"for a in ax:\n",
"    a.set_xticks([])\n",
"    a.set_yticks([])\n",
"    # a.set_ylim(120, 70)\n",
"plt.tight_layout()\n",
"plt.show()"
],
"metadata": {
"id": "Tx1yiDeFh7Yl",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 919
},
"outputId": "3e427137-57b7-4bfb-88e1-ef2de5522bc2"
},
"execution_count": 28,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Phif shape (33,)\n",
"data shape (2049,)\n"
]
},
{
"output_type": "error",
"ename": "TypeError",
"evalue": "mul got incompatible shapes for broadcasting: (32,), (33,).",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-28-9015245b2ced>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0mwave_conv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform_conv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjax_yf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjax_phif\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0;31m# if cupy_available:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-28-9015245b2ced>\u001b[0m in \u001b[0;36mtransform_conv\u001b[0;34m(data, Nf, Nt, phif)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtransform_conv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNf\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNt\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphif\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m wave= jax.lax.fori_loop(\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNt\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccum\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maccum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxn_i\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphif\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
" \u001b[0;31m[... skipping hidden 13 frame]\u001b[0m\n",
"\u001b[0;32m<ipython-input-28-9015245b2ced>\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(i, accum)\u001b[0m\n\u001b[1;32m 26\u001b[0m wave= jax.lax.fori_loop(\n\u001b[1;32m 27\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNt\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0;32mlambda\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccum\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maccum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxn_i\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphif\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 29\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mNt\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomplex64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m )\n",
"\u001b[0;32m<ipython-input-28-9015245b2ced>\u001b[0m in \u001b[0;36mxn_i\u001b[0;34m(m, data, Nf, Nt, phif)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfft\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mifft\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroll\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m*\u001b[0m \u001b[0mNt\u001b[0m\u001b[0;34m//\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mNt\u001b[0m\u001b[0;34m//\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mphif\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/array_methods.py\u001b[0m in \u001b[0;36mop\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 1058\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_forward_operator_to_aval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1059\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1060\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf\"_{name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1061\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1062\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/array_methods.py\u001b[0m in \u001b[0;36mdeferring_binary_op\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 577\u001b[0m \u001b[0margs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mswap\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 578\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_accepted_binop_types\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 579\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mbinary_op\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 580\u001b[0m \u001b[0;31m# Note: don't use isinstance here, because we don't want to raise for\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 581\u001b[0m \u001b[0;31m# subclasses, e.g. NamedTuple objects that may override operators.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/ufunc_api.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, out, where, *args)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"where argument of {self}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[0mcall\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__static_props\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'call'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_vectorized\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 180\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 181\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 182\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatic_argnames\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'self'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
" \u001b[0;31m[... skipping hidden 13 frame]\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/ufuncs.py\u001b[0m in \u001b[0;36mmultiply\u001b[0;34m(x, y)\u001b[0m\n\u001b[1;32m 1254\u001b[0m \"\"\"\n\u001b[1;32m 1255\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpromote_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"multiply\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1256\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mlax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mlax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbitwise_and\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1257\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1258\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
" \u001b[0;31m[... skipping hidden 9 frame]\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/jax/_src/lax/lax.py\u001b[0m in \u001b[0;36m_try_broadcast_shapes\u001b[0;34m(name, *shapes)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mresult_shape\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnon_1s\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m raise TypeError(f'{name} got incompatible shapes for broadcasting: '\n\u001b[0m\u001b[1;32m 133\u001b[0m f'{\", \".join(map(str, map(tuple, shapes)))}.')\n\u001b[1;32m 134\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: mul got incompatible shapes for broadcasting: (32,), (33,)."
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"phif"
],
"metadata": {
"id": "sYZcibSW7mbF",
"outputId": "e33bacc9-e926-40b2-f338-841e2496e38f",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 26,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0.35355339, 0.35355339, 0.35355339, 0.35355339, 0.35355339,\n",
" 0.35355339, 0.35355339, 0.35355339, 0.35355339, 0.35355339,\n",
" 0.35355339, 0.35355339, 0.35355339, 0.35355339, 0.35355339,\n",
" 0.35355339, 0.35355339, 0.35355339, 0.35355339, 0.35355339,\n",
" 0.35355339, 0.35355339, 0.35355339, 0.35355339, 0.35355339,\n",
" 0.35355339, 0.35355339, 0.35355339, 0.35355339, 0.35355339,\n",
" 0.35355339, 0.35355339, 0.35355339])"
]
},
"metadata": {},
"execution_count": 26
}
]
},
{
"cell_type": "code",
"source": [
"1 - Nt // 2, Nt // 2"
],
"metadata": {
"id": "7BP4UO8O5qh-",
"outputId": "b3dba92b-5693-449e-bf2f-58bad8779e6c",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 22,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(-31, 32)"
]
},
"metadata": {},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "yoeuSRb46J4_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"phif"
],
"metadata": {
"id": "5wv3FDxY7d8n",
"outputId": "2e66e3b1-7f09-438c-a1f6-50c25d72a289",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 24,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0.35355339, 0.35355339, 0.35355339, 0.35355339, 0.35355339,\n",
" 0.35355339, 0.35355339, 0.35355339, 0.35355339, 0.35355339,\n",
" 0.35355339, 0.35355339, 0.35355339, 0.35355339, 0.35355339,\n",
" 0.35355339, 0.35355339, 0.35355339, 0.35355339, 0.35355339,\n",
" 0.35355339, 0.35355339, 0.35355339, 0.35355339, 0.35355339,\n",
" 0.35355339, 0.35355339, 0.35355339, 0.35355339, 0.35355339,\n",
" 0.35355339, 0.35355339, 0.35355339])"
]
},
"metadata": {},
"execution_count": 24
}
]
},
{
"cell_type": "markdown",
"source": [
"## Timing comparisons"
],
"metadata": {
"id": "v8MIBbNHmndU"
}
},
{
"cell_type": "code",
"source": [
"%timeit transform_wavelet_freq_helper_numba(yf, Nf, Nt, phif)"
],
"metadata": {
"id": "yT8KZlsZjFZq",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4d79e0d4-9276-44d6-aa1f-6bffddc0c690"
},
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"872 µs ± 14.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%timeit transform_wavelet_freq_helper_JAX(jax_yf, Nf, Nt, jax_phif)"
],
"metadata": {
"id": "X3ObJ_BrjIHC"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%timeit transform_wavelet_freq_helper_CuPy(cupy_yf, Nf, Nt, cupy_phif)"
],
"metadata": {
"id": "UdHonCg3OaT9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import time\n",
"from tqdm.auto import tqdm\n",
"import pandas as pd\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def time_function(func, *args):\n",
" try:\n",
" t0 = time.time()\n",
" func(*args)\n",
" return time.time() - t0\n",
" except Exception:\n",
" return np.nan\n",
"\n",
"\n",
"def run_numba_timings(Nf=2**6, Nt=None, n_rep=10):\n",
" \"\"\"Run timing for Numba implementation.\"\"\"\n",
" Nt = Nt or Nf\n",
" yf, phif = simulate_data(Nf)\n",
" # Warmup\n",
" transform_wavelet_freq_helper_numba(yf, Nf, Nt, phif)\n",
"\n",
" times = np.zeros(n_rep)\n",
" for i in range(n_rep):\n",
" times[i] = time_function(transform_wavelet_freq_helper_numba, yf, Nf, Nt, phif)\n",
" return times\n",
"\n",
"def run_jax_timings(Nf=2**6, Nt=None, n_rep=10):\n",
" \"\"\"Run timing for JAX implementation.\"\"\"\n",
" Nt = Nt or Nf\n",
" yf, phif = simulate_data(Nf)\n",
" jax_yf, jax_phif = jnp.array(yf), jnp.array(phif)\n",
" # Warmup\n",
" transform_wavelet_freq_helper_JAX(jax_yf, Nf, Nt, jax_phif)\n",
"\n",
" times = np.zeros(n_rep)\n",
" for i in range(n_rep):\n",
" times[i] = time_function(transform_wavelet_freq_helper_JAX, jax_yf, Nf, Nt, jax_phif)\n",
" return times\n",
"\n",
"def run_cupy_timings(Nf=2**6, Nt=None, n_rep=10):\n",
" \"\"\"Run timing for CuPy implementation.\"\"\"\n",
" Nt = Nt or Nf\n",
"\n",
" if (Nf >= 8192 # Cupy seems to run out of mem here.\n",
" or not cupy_available):\n",
" return np.array([np.nan] * n_rep)\n",
"\n",
" yf, phif = simulate_data(Nf)\n",
" cupy_yf, cupy_phif = cp.array(yf), cp.array(phif)\n",
" # Warmup\n",
" transform_wavelet_freq_helper_CuPy(cupy_yf, Nf, Nt, cupy_phif)\n",
"\n",
" times = np.zeros(n_rep)\n",
" for i in range(n_rep):\n",
" times[i] = time_function(transform_wavelet_freq_helper_CuPy, cupy_yf, Nf, Nt, cupy_phif)\n",
" return times\n",
"\n",
"def run_experiments(min_n=5, max_n=10, n_rep=10):\n",
" \"\"\"\n",
" Run experiments for different sizes and collect runtime data.\n",
"\n",
" Parameters:\n",
" - min_n: Minimum power of 2 for the size of Nf.\n",
" - max_n: Maximum power of 2 for the size of Nf.\n",
" - n_rep: Number of repetitions for each size.\n",
"\n",
" Returns:\n",
" - pd.DataFrame: DataFrame containing runtimes for different methods and sizes.\n",
" \"\"\"\n",
" jax_label = f\"jax[{DEVICE}]_times\"\n",
" data = {\n",
" \"Ns\": [],\n",
" \"numba_times\": [],\n",
" jax_label: [],\n",
" \"cupy_times\": []\n",
" }\n",
"\n",
" Nfs = [2**i for i in range(min_n, max_n)]\n",
"\n",
" pbar = tqdm(Nfs, desc=\"Running experiments\")\n",
" for Nf in pbar:\n",
" pbar.set_postfix({\"Nf\": Nf})\n",
" numba_times = run_numba_timings(Nf=Nf, n_rep=n_rep)\n",
" jax_times = run_jax_timings(Nf=Nf, n_rep=n_rep)\n",
" cupy_times = run_cupy_timings(Nf=Nf, n_rep=n_rep)\n",
"\n",
" data[\"Ns\"].extend([Nf] * n_rep)\n",
" data[\"numba_times\"].extend(numba_times)\n",
" data[jax_label].extend(jax_times)\n",
" data[\"cupy_times\"].extend(cupy_times)\n",
"\n",
" return pd.DataFrame(data)\n",
"\n",
"\n",
"\n",
"# Running the experiments and collecting data\n",
"data = run_experiments(max_n=14, n_rep=5)\n",
"data.to_csv(\"runtimes.csv\", index=False)"
],
"metadata": {
"id": "rXA0cHe5gahJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import os\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"def plot_runtime_vs_size(data, column_label='numba', color='blue', ax=None):\n",
" \"\"\"\n",
" Plots the runtime vs size graph with error bands for a given column in the dataset.\n",
"\n",
" Parameters:\n",
" - data: DataFrame containing the data to be plotted.\n",
" - column_label: Label for the column prefix (e.g., 'numba', 'jax', 'cupy') to plot.\n",
" - color: Color for the plot lines and fill area (default: 'blue').\n",
" - ax: Optional matplotlib axis to plot on. If None, a new figure and axis are created.\n",
"\n",
" Returns:\n",
" - ax: The axis with the plot.\n",
" \"\"\"\n",
"\n",
" df = data.copy()\n",
" grouped = df.groupby('Ns').agg(\n",
" mean=(f'{column_label}_times', 'mean'),\n",
" sem=(f'{column_label}_times', 'sem')\n",
" ).reset_index()\n",
" grouped['exponent'] = np.log2(grouped['Ns']).astype(int)\n",
"\n",
" # Create an axis if not provided\n",
" if ax is None:\n",
" fig, ax = plt.subplots(figsize=(4, 3))\n",
"\n",
"\n",
" # Plotting\n",
" ax.fill_between(grouped['exponent'],\n",
" grouped['mean'] - grouped['sem'],\n",
" grouped['mean'] + grouped['sem'],\n",
" color=color, alpha=0.3, label=f'{column_label.capitalize()}', lw=0)\n",
" ax.plot(grouped['exponent'], grouped['mean'], color=color, marker='')\n",
"\n",
" # Set labels, legend\n",
" ax.set_xlabel('$N_fxN_t$', fontsize=12)\n",
" ax.set_ylabel('Runtime [s]', fontsize=12)\n",
" ax.legend(fontsize=10, frameon=False)\n",
"\n",
" ax.grid(True, which=\"both\", linestyle=\"-\", alpha=0.2)\n",
" fmt = '$2^{e}x2^{e}$'\n",
" ax.set_xticks(grouped['exponent'])\n",
" ax.set_xticklabels([fmt.replace(\"e\", str(exp)) for exp in grouped['exponent']])\n",
" ax.tick_params(axis='x', rotation=45)\n",
" ax.set_yscale('log')\n",
" plt.tight_layout()\n",
" return ax\n",
"\n",
"if 'runtimes_gpu.csv' in os.listdir():\n",
" df_gpu = pd.read_csv(\"runtimes_gpu.csv\")\n",
" df_tpu = pd.read_csv(\"runtimes_tpu.csv\")\n",
" df = df_gpu.copy()\n",
" df['jax[TPU v2]_times'] = df_tpu['jax[TPU v2]_times']\n",
"else:\n",
" df = pd.read_csv(\"runtimes.csv\")\n",
"fig, ax = plt.subplots(figsize=(4, 3))\n",
"labels = [i.split(\"_times\")[0] for i in df.columns.values if 'times' in i]\n",
"for i, l in enumerate(labels):\n",
" plot_runtime_vs_size(df, column_label=l, color=f\"C{i}\", ax=ax)"
],
"metadata": {
"id": "bNoJIMuHFftq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Note: we've used lower precision data for JAX + Cupy (this can be changed)."
],
"metadata": {
"id": "m5UPtRxf32cz"
}
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment