Skip to content

Instantly share code, notes, and snippets.

@bmorris3
Created March 15, 2023 16:04
Show Gist options
  • Save bmorris3/ab6f82ff3780abc2e78fea434dd78b87 to your computer and use it in GitHub Desktop.
Save bmorris3/ab6f82ff3780abc2e78fea434dd78b87 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "eab1828a-98c0-4152-a5f9-c5a93323e922",
"metadata": {},
"source": [
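"# Remove radial-velocity (RV) shifts and co-add model spectra across orbital phase\n",
"\n",
"Each row of the model flux grid corresponds to one velocity. We undo that row's wavelength shift by interpolating it onto a correspondingly shifted copy of the observed wavelength grid, then sum (co-add) the de-shifted rows into a single spectrum."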
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d8eed62-3b53-4d51-b4b5-ba1b863a0f66",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from jax import lax, numpy as jnp\n",
"\n",
"# Synthetic \"model\" spectra: one cosine per row, each row shifted in\n",
"# wavelength by the corresponding entry of `wavelength_shifts`.\n",
"x = np.linspace(0, 10, 30)\n",
"\n",
"wavelength_shifts = np.linspace(0, 1, 20)\n",
"\n",
"# Period of the synthetic cosine feature, in the same units as `x`.\n",
"PERIOD = 4\n",
"\n",
"# Broadcasting (20, 1) against (30,) yields one shifted spectrum per row:\n",
"# y.shape == (len(wavelength_shifts), len(x)).\n",
"y = np.cos(2 * np.pi / PERIOD * (x - wavelength_shifts[:, None]))\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(y, aspect=\"auto\")\n",
"ax.set(\n",
"    title=\"shifted model spectra\",\n",
"    xlabel=\"wavelength index\",\n",
"    ylabel=\"row (shift) index\"\n",
")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7cae7dc2-eb92-489f-aee4-b3348415e733",
"metadata": {},
"outputs": [],
"source": [
"# Toy setup: reuse `x` as the model wavelength grid and `y` as the model\n",
"# fluxes. Real data will have different arrays for these.\n",
"wavelength = x.copy()\n",
"flux = y.copy()\n",
"\n",
"# `velocity` must have one entry per row of `flux`, i.e.\n",
"# len(wavelength_shifts) entries (not len(x)). In this toy, each row's\n",
"# shift stands in for its velocity, in arbitrary units.\n",
"velocity = wavelength_shifts.copy()\n",
"\n",
"# Pretend we're interpolating the model fluxes onto an \"observed\"\n",
"# wavelength grid that is only slightly different from the model grid.\n",
"# Seed the RNG so that Restart & Run All reproduces the same grid.\n",
"SEED = 42\n",
"rng = np.random.default_rng(SEED)\n",
"observed_wavelength = x + rng.normal(scale=0.01, size=x.size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff4c4192-32cb-428b-8b9c-b6ecbf59c897",
"metadata": {},
"outputs": [],
"source": [
"def shift_and_coadd(\n",
"    observed_wavelength, wavelength_shifts, velocity, wavelength, flux\n",
"):\n",
"    \"\"\"\n",
"    Undo each velocity's wavelength shift, then co-add the model spectra.\n",
"\n",
"    1. Iterate over each velocity (one row of ``flux`` per velocity)\n",
"    2. Interpolate that row onto ``observed_wavelength + wavelength_shifts[j]``\n",
"    3. Sum the de-shifted rows over all velocities\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    observed_wavelength : array-like, shape (M,)\n",
"        Wavelength grid to interpolate onto, before the per-row shift.\n",
"    wavelength_shifts : array-like, shape (N,)\n",
"        Wavelength shift to undo for each row of ``flux``.\n",
"    velocity : array-like, shape (N,)\n",
"        Velocity of each row of ``flux``; only its length is used here.\n",
"    wavelength : array-like, shape (K,)\n",
"        Model wavelength grid. Must be increasing (not checked by interp).\n",
"    flux : array-like, shape (N, K)\n",
"        Model fluxes; row ``j`` corresponds to ``velocity[j]``.\n",
"\n",
"    Returns\n",
"    -------\n",
"    flux_interp : jax.Array, shape (N, M)\n",
"        De-shifted fluxes on the observed grid. Points whose shifted\n",
"        wavelength falls outside ``wavelength`` are NaN.\n",
"    F : jax.Array, shape (M,)\n",
"        Sum of ``flux_interp`` over velocities. NaNs propagate, so a\n",
"        column is NaN if any row was out of bounds there.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If the input lengths are inconsistent. JAX clamps out-of-range\n",
"        indices instead of raising, so mismatches must be caught explicitly.\n",
"    \"\"\"\n",
"    observed_wavelength = jnp.asarray(observed_wavelength)\n",
"    wavelength_shifts = jnp.asarray(wavelength_shifts)\n",
"    wavelength = jnp.asarray(wavelength)\n",
"    flux = jnp.asarray(flux)\n",
"\n",
"    n_velocities = len(velocity)\n",
"    if flux.ndim != 2:\n",
"        raise ValueError(f\"flux must be 2D, got shape {flux.shape}\")\n",
"    if wavelength_shifts.shape != (n_velocities,) or flux.shape[0] != n_velocities:\n",
"        raise ValueError(\n",
"            f\"expected {n_velocities} wavelength shifts and flux rows (one per \"\n",
"            f\"velocity), got shifts of shape {wavelength_shifts.shape} and \"\n",
"            f\"{flux.shape[0]} flux rows\"\n",
"        )\n",
"    if flux.shape[1] != wavelength.shape[0]:\n",
"        raise ValueError(\n",
"            f\"flux has {flux.shape[1]} columns but wavelength has \"\n",
"            f\"{wavelength.shape[0]} points\"\n",
"        )\n",
"\n",
"    def deshift_row(shift_and_row):\n",
"        # Sample one model spectrum at the observed grid moved by this row's\n",
"        # shift; out-of-range points become NaN rather than edge values.\n",
"        shift, row = shift_and_row\n",
"        return jnp.interp(\n",
"            observed_wavelength + shift, wavelength, row,\n",
"            left=jnp.nan, right=jnp.nan,\n",
"        )\n",
"\n",
"    # lax.map applies deshift_row to each (shift, row) pair and stacks the\n",
"    # results; iterating over the arrays directly avoids index clamping.\n",
"    flux_interp = lax.map(deshift_row, (wavelength_shifts, flux))\n",
"\n",
"    # Co-add: sum over the velocity axis.\n",
"    F = jnp.sum(flux_interp, axis=0)\n",
"\n",
"    return flux_interp, F"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b772fe73-9dfd-451a-83c8-92c74728c0a1",
"metadata": {},
"outputs": [],
"source": [
"# Each row of flux_interp_sum is one de-shifted spectrum; F is their sum.\n",
"flux_interp_sum, F = shift_and_coadd(\n",
"    observed_wavelength=observed_wavelength,\n",
"    wavelength_shifts=wavelength_shifts,\n",
"    velocity=velocity,\n",
"    wavelength=wavelength,\n",
"    flux=flux,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d8f7283-7961-4d75-b2e8-009c5ec3ae6a",
"metadata": {},
"outputs": [],
"source": [
"# The de-shifted spectra and their sum are sampled on `observed_wavelength`\n",
"# (the grid passed to shift_and_coadd), so plot against that grid rather\n",
"# than the model `wavelength` grid; the two can differ in length.\n",
"fig, ax = plt.subplots(2, 1, figsize=(4, 7), sharex=True)\n",
"ax[0].pcolormesh(\n",
"    observed_wavelength, velocity, flux_interp_sum, shading=\"auto\"\n",
")\n",
"ax[0].set(\n",
"    title=\"de-shifted, 2D\",\n",
"    ylabel=\"velocity\"\n",
")\n",
"# NaN values in F (columns where some row fell outside the model grid)\n",
"# appear as gaps in the line.\n",
"ax[1].plot(observed_wavelength, F)\n",
"ax[1].set(\n",
"    title=\"coadded\",\n",
"    xlabel=\"wavelength\",\n",
"    ylabel=\"summed flux\"\n",
")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2606d04b-b318-49e6-a20e-fa766ad27098",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment