Created
March 15, 2023 16:04
-
-
Save bmorris3/ab6f82ff3780abc2e78fea434dd78b87 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "eab1828a-98c0-4152-a5f9-c5a93323e922", | |
"metadata": {}, | |
"source": [ | |
"# Remove radial-velocity (RV) shifts and coadd model spectra (flux vs. orbital phase)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4d8eed62-3b53-4d51-b4b5-ba1b863a0f66", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"from jax import lax, numpy as jnp\n", | |
"\n", | |
# Toy "model spectra": row j is a cosine with period 4 evaluated at
# x - wavelength_shifts[j], so each successive row is shifted a little
# further along the x (wavelength) axis.
x = np.linspace(0, 10, 30)
wavelength_shifts = np.linspace(0, 1, 20)
y = np.cos(2 * np.pi / 4 * (x - wavelength_shifts[:, None]))
"\n", | |
# Show the shifted toy spectra as an image: one row per wavelength shift,
# one column per wavelength sample. Explicit fig/ax plus plt.show() keeps
# the AxesImage repr out of the cell output.
fig, ax = plt.subplots()
ax.imshow(y, aspect="auto")
ax.set(
    title="model spectra before de-shifting",
    xlabel="wavelength index",
    ylabel="row index (one per wavelength shift)",
)
plt.show()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "7cae7dc2-eb92-489f-aee4-b3348415e733", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Toy inputs for `shift_and_coadd`. In a real analysis these come from
# your model grid and your observations.
#
# Each row of `flux` is the model spectrum at one velocity, so `velocity`
# must have exactly one entry per row (i.e. one per wavelength shift).
# JAX clamps out-of-bounds indices instead of raising, so a longer
# `velocity` array would silently reuse the last row. In this toy example
# the velocity "label" of each row is simply its wavelength shift
# (arbitrary units).
velocity = wavelength_shifts.copy()
wavelength = x.copy()
flux = y.copy()
assert velocity.shape[0] == flux.shape[0] == wavelength_shifts.shape[0]

# I'm pretending that we're interpolating the model fluxes onto
# an "observed" wavelength grid which is only slightly different
# from the original wavelength axis. The generator is seeded so that
# Restart & Run All reproduces the same grid every time.
rng = np.random.default_rng(42)
observed_wavelength = x + rng.normal(scale=0.01, size=x.size)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ff4c4192-32cb-428b-8b9c-b6ecbf59c897", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
def shift_and_coadd(
    observed_wavelength, wavelength_shifts, velocity, wavelength, flux
):
    """
    Remove per-velocity wavelength shifts from model spectra and coadd them.

    1. Iterate over each velocity, i.e. over each row of ``flux``
    2. Interpolate the model fluxes at each v onto the observed wavelength
       grid, offset by that row's wavelength shift
    3. Sum up the result over all velocities

    Parameters
    ----------
    observed_wavelength : array-like, shape (n_obs,)
        Wavelength grid to interpolate onto.
    wavelength_shifts : array-like, shape (n_velocities,)
        Amount added to ``observed_wavelength`` before interpolating row j.
    velocity : array-like, shape (n_velocities,)
        Velocity of each row of ``flux``. Only its length is used here, to
        check that it matches the number of rows.
    wavelength : array-like, shape (n_wavelength,)
        Model wavelength grid. It is passed as ``xp`` to ``jnp.interp``, so
        it is assumed to be increasing.
    flux : array-like, shape (n_velocities, n_wavelength)
        Model fluxes; the jth row corresponds to the jth velocity.

    Returns
    -------
    flux_interp_sum : jax.Array, shape (n_velocities, n_obs)
        De-shifted fluxes, one row per velocity. Samples that fall outside
        ``wavelength`` are NaN.
    F : jax.Array, shape (n_obs,)
        Sum of ``flux_interp_sum`` over velocities. It is NaN in every
        column where at least one row is out of bounds.

    Raises
    ------
    ValueError
        If ``flux`` is not 2-D, or if ``velocity`` or ``wavelength_shifts``
        does not have exactly one entry per row of ``flux``.
    """
    observed_wavelength = jnp.asarray(observed_wavelength)
    wavelength_shifts = jnp.asarray(wavelength_shifts)
    wavelength = jnp.asarray(wavelength)
    flux = jnp.asarray(flux)

    if flux.ndim != 2:
        raise ValueError(
            "flux must be 2-D (n_velocities, n_wavelength); "
            f"got shape {flux.shape}"
        )
    n_velocities = flux.shape[0]

    # JAX clamps out-of-bounds indices instead of raising, so a length
    # mismatch here would silently reuse the last shift/row. Fail loudly.
    if len(velocity) != n_velocities or len(wavelength_shifts) != n_velocities:
        raise ValueError(
            f"velocity ({len(velocity)}) and wavelength_shifts "
            f"({len(wavelength_shifts)}) must each have one entry per row "
            f"of flux ({n_velocities})"
        )

    def deshift_row(carry, j):
        # Evaluate row j at observed_wavelength + wavelength_shifts[j]; for a
        # row that was shifted by +wavelength_shifts[j] this undoes the shift.
        # Out-of-range samples become NaN instead of being extrapolated.
        deshifted = jnp.interp(
            observed_wavelength + wavelength_shifts[j],
            wavelength,
            flux[j],
            left=jnp.nan,
            right=jnp.nan,
        )
        return carry, deshifted

    # lax.scan is used as a map: the carry is unused, and the per-row outputs
    # are stacked along a new leading (velocity) axis.
    _, flux_interp_sum = lax.scan(deshift_row, 0.0, jnp.arange(n_velocities))

    # now the sum over velocities (rows) is the coadded spectrum:
    F = jnp.sum(flux_interp_sum, axis=0)

    return flux_interp_sum, F
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b772fe73-9dfd-451a-83c8-92c74728c0a1", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# De-shift every row of `flux` onto the observed grid, then coadd the rows.
flux_interp_sum, F = shift_and_coadd(
    observed_wavelength=observed_wavelength,
    wavelength_shifts=wavelength_shifts,
    velocity=velocity,
    wavelength=wavelength,
    flux=flux,
)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "9d8f7283-7961-4d75-b2e8-009c5ec3ae6a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# `flux_interp_sum` and `F` are sampled on the *observed* wavelength grid
# (that is the grid shift_and_coadd interpolates onto), so plot them against
# `observed_wavelength`, not the model grid `wavelength`. With real data the
# two grids differ and need not even have the same length.
fig, ax = plt.subplots(2, 1, figsize=(4, 7), sharex=True)
ax[0].pcolormesh(observed_wavelength, velocity, flux_interp_sum, shading="auto")
ax[0].set(
    title="de-shifted, 2D",
    ylabel="velocity"
)
ax[1].plot(observed_wavelength, F)
ax[1].set(
    title="coadded",
    xlabel="wavelength",
    ylabel="flux summed over velocities"
)
plt.show()
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "2606d04b-b318-49e6-a20e-fa766ad27098", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment