Skip to content

Instantly share code, notes, and snippets.

@bmorris3
Created November 29, 2023 21:07
Show Gist options
  • Save bmorris3/4e51b0ce1bca5f3865700d4975788efb to your computer and use it in GitHub Desktop.
Save bmorris3/4e51b0ce1bca5f3865700d4975788efb to your computer and use it in GitHub Desktop.
JAX implementation of fleck, available in this branch: https://github.com/bmorris3/fleck/tree/jax
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "46d59629-0e60-4285-97e8-8e892b5e7d6b",
"metadata": {},
"source": [
"# fleck in JAX\n",
"\n",
"This notebook demonstrates how to compute spectroscopic rotational modulation:\n",
"* for $N$ active regions with one uniform contrast, defined by a blackbody spectrum, on one star\n",
"* for $N$ active regions with $M$ active region contrasts, defined by PHOENIX spectra, on one star\n",
"* for $N$ active regions with $M$ active region contrasts, on $O$ stars observed at $O$ inclinations"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f085e9f-63dc-4b0b-9a91-dfec33d2bbe1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from jax import jit, numpy as jnp"
]
},
{
"cell_type": "markdown",
"id": "7fdc7b1b-6d7a-461f-b70b-d8c6be602aae",
"metadata": {},
"source": [
"The time axis is in units of rotational phase, in radians (one full rotation spans $0$ to $2\\pi$):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "009bb930-9eac-4379-ac05-3a2e45a2951c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# one full rotation sampled at 150 points, in radians\n",
"n_phase = 150\n",
"phase = np.linspace(0.0, 2.0 * np.pi, num=n_phase)"
]
},
{
"cell_type": "markdown",
"id": "4ada5f5d-b749-47ce-8fb8-e29cd0bccf6d",
"metadata": {},
"source": [
"### Blackbody stellar photosphere and active region spectra"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63f74b69-45aa-4b84-b4dc-c6dbf7e7c679",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import astropy.units as u\n",
"from astropy.modeling.models import BlackBody\n",
"\n",
"# 30 logarithmically spaced wavelengths from 0.2 to 5 µm\n",
"wavelength = np.geomspace(0.2, 5, 30) * u.um\n",
"\n",
"# photosphere and active-region temperatures\n",
"# (alternative setup: T_phot = 4500 K with T_active = 2500 K)\n",
"T_phot = 2700 * u.K\n",
"T_active = 2500 * u.K\n",
"\n",
"# evaluate each blackbody on the wavelength grid (multiplied by u.sr)\n",
"bb_phot, bb_active = [\n",
"    BlackBody(temperature=T)(wavelength) * u.sr for T in (T_phot, T_active)\n",
"]\n",
"\n",
"# active-region contrast: active-region intensity over photospheric intensity\n",
"contrast = bb_active / bb_phot\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(wavelength, contrast)\n",
"ax.set(xlabel='Wavelength [µm]', ylabel='Contrast')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff214683-51e9-4c15-be59-69891c947a51",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from fleck.jax import rotation_model\n",
"\n",
"n_spots = 4\n",
"n_stars = 1\n",
"\n",
"# spot longitudes, latitudes and radii, evenly spread between the given bounds\n",
"lon = jnp.linspace(0, 2*np.pi, n_spots)\n",
"lat = jnp.linspace(1.5, 2.5, n_spots)\n",
"rad = jnp.linspace(0.2, 0.3, n_spots)\n",
"# the blackbody contrast as a plain JAX array\n",
"contrast = jnp.array(contrast)\n",
"\n",
"# a single star at pi/2, or several stars spanning inclinations 0 to pi/2\n",
"inclination = np.pi / 2 if n_stars == 1 else jnp.linspace(0, np.pi/2, n_stars)\n",
"\n",
"lc = rotation_model(phase, lon, lat, rad, contrast, inclination, f0=0)\n",
"\n",
"# plot only the last inclination (equator-on orientation):\n",
"fig, ax = plt.subplots()\n",
"mesh = ax.pcolormesh(wavelength.value, phase, 1 + lc[..., -1], cmap=plt.cm.Greys_r)\n",
"fig.colorbar(mesh, label='Relative flux')\n",
"ax.set(\n",
"    xlabel='Wavelength [µm]',\n",
"    ylabel='Phase [rad]'\n",
")"
]
},
{
"cell_type": "markdown",
"id": "38f120c1-0c32-4a17-9447-5308fefc76ef",
"metadata": {},
"source": [
"### PHOENIX model spectra"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6e5c64fc-bd9a-4ef3-b0b0-d498772099b8",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from expecto import get_spectrum\n",
"\n",
"from fleck.jax import bin_spectrum\n",
"\n",
"# PHOENIX spectra at the same temperatures as the blackbodies above\n",
"phoenix_phot, phoenix_active = (\n",
"    get_spectrum(T.value, 4.5, cache=True) for T in (T_phot, T_active)\n",
")\n",
"\n",
"# contrast spectrum binned into 1,000 bins between 0.5 and 10 µm\n",
"phoenix_contrast = bin_spectrum(\n",
"    phoenix_active / phoenix_phot,\n",
"    bins=1_000,\n",
"    min=0.5*u.um,\n",
"    max=10*u.um\n",
")\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(wavelength, contrast, label='Blackbody')\n",
"ax.semilogx(phoenix_contrast.wavelength.to_value(u.um), phoenix_contrast.flux, label='PHOENIX')\n",
"ax.legend()\n",
"ax.set(xlabel='Wavelength [µm]', ylabel='Contrast')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75b3e20d-d4ae-4daf-af35-b5b1784a3dd9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"lc = rotation_model(phase, lon, lat, rad, phoenix_contrast.flux, inclination, f0=0)\n",
"\n",
"# plot the last inclination (pi/2; the only one when n_stars == 1).\n",
"# A fixed index such as 2 is out of bounds for a single star, and JAX array\n",
"# indexing clamps out-of-bounds indices instead of raising an error.\n",
"plot_lc = 1 + lc[..., -1]\n",
"# normalize each wavelength channel by its mean over rotational phase\n",
"plot_lc = plot_lc / np.nanmean(plot_lc, axis=0)[None, :]\n",
"\n",
"vmin = np.nanmin(plot_lc)\n",
"vmax = np.nanmax(plot_lc)\n",
"\n",
"# fall back to a fixed +/-10% color range if the modulation is negligible\n",
"if vmax - vmin < 1e-4:\n",
"    vmin = 0.9\n",
"    vmax = 1.1\n",
"\n",
"c = plt.pcolormesh(\n",
"    phoenix_contrast.wavelength.to_value(u.um), phase, plot_lc,\n",
"    cmap=plt.cm.Greys_r,\n",
"    vmin=vmin, vmax=vmax\n",
")\n",
"plt.colorbar(c, label='Relative flux')\n",
"plt.gca().set(\n",
"    xscale='log',\n",
"    xlabel='Wavelength [µm]',\n",
"    ylabel='Phase [rad]'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d775735e-376b-46a6-b086-35766c9577b5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"%%timeit\n",
"\n",
"# JAX dispatches work asynchronously: block until the result is computed so\n",
"# the timing covers the computation, not just the dispatch\n",
"lc = rotation_model(phase, lon, lat, rad, phoenix_contrast.flux, inclination, f0=0).block_until_ready()"
]
},
{
"cell_type": "markdown",
"id": "1c77072e-3c73-4af6-88cc-604d7c461c85",
"metadata": {},
"source": [
"### Two temperatures for active regions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c104147d-e8c2-4a06-a18e-8997d97f4220",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# photosphere plus a hot and a cool active-region spectrum\n",
"phoenix_phot = get_spectrum(T_phot.value, 4.5, cache=True)\n",
"phoenix_hot = get_spectrum(5000, 4.5, cache=True)\n",
"phoenix_cool = get_spectrum(2300, 4.5, cache=True)\n",
"\n",
"# wavelength window shared by every binned spectrum below\n",
"lam_min = 0.8 * u.um\n",
"lam_max = 4 * u.um\n",
"\n",
"def bin_to_window(spectrum):\n",
"    \"\"\"Bin `spectrum` into 1,000 bins between `lam_min` and `lam_max`.\"\"\"\n",
"    return bin_spectrum(spectrum, bins=1_000, min=lam_min, max=lam_max)\n",
"\n",
"# binned photosphere (not used further in this notebook)\n",
"phoenix_phot_binned = bin_to_window(phoenix_phot)\n",
"\n",
"# contrast spectra: active-region flux divided by photospheric flux\n",
"phoenix_contrast_hot = bin_to_window(phoenix_hot / phoenix_phot)\n",
"phoenix_contrast_cool = bin_to_window(phoenix_cool / phoenix_phot)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "57eb0989-9172-40fa-81bc-0eb76d3e6e0e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"fig, ax = plt.subplots()\n",
"for label, spectrum in [('cool', phoenix_contrast_cool), ('hot', phoenix_contrast_hot)]:\n",
"    ax.semilogy(spectrum.wavelength.to_value(u.um), spectrum.flux, label=label)\n",
"ax.legend()\n",
"ax.set(xlabel='Wavelength [µm]', ylabel='Contrast')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8aa3691f-47df-4be8-82ad-b93b74910c2d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"n_spots = 4\n",
"n_stars = 1\n",
"\n",
"lon = jnp.linspace(0, 2*np.pi, n_spots)\n",
"lat = jnp.linspace(1.5, 2.5, n_spots)\n",
"# two large spots followed by two small ones\n",
"rad = jnp.array([0.1, 0.1, 0.015, 0.01])\n",
"\n",
"# one contrast spectrum per spot (rows follow the order of `rad`):\n",
"# the first two spots are cool, the last two are hot\n",
"contrast_per_spot = jnp.vstack(\n",
"    2 * [phoenix_contrast_cool.flux] + 2 * [phoenix_contrast_hot.flux]\n",
")\n",
"\n",
"# a single star at pi/2, or several stars spanning inclinations 0 to pi/2\n",
"inclination = np.pi / 2 if n_stars == 1 else jnp.linspace(0, np.pi/2, n_stars)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17abbb1e-55af-484f-82ee-572004a61247",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from matplotlib.gridspec import GridSpec\n",
"\n",
"lc = rotation_model(phase, lon, lat, rad, contrast_per_spot, inclination, f0=0)\n",
"\n",
"# relative flux for the first (and, with n_stars == 1, only) inclination;\n",
"# every panel below is drawn from this single (phase, wavelength) array so the\n",
"# side panels stay consistent with the map\n",
"plot_lc = 1 + lc[..., 0]\n",
"wl_um = phoenix_contrast_cool.wavelength.to_value(u.um)\n",
"\n",
"fig = plt.figure(figsize=(14, 6))\n",
"gs = GridSpec(2, 4, figure=fig)\n",
"\n",
"# left: phase-wavelength map; right: light curves (top) and spectra (bottom)\n",
"ax_left = fig.add_subplot(gs[:, 0:2])\n",
"ax_time = fig.add_subplot(gs[0, -2:])\n",
"ax_wave = fig.add_subplot(gs[1, -2:])\n",
"\n",
"c = ax_left.pcolormesh(wl_um, phase, plot_lc, cmap=plt.cm.Greys_r)\n",
"plt.colorbar(c, label='Relative flux', ax=ax_left)\n",
"ax_left.set(\n",
"    xlabel='Wavelength [µm]',\n",
"    ylabel='Phase [rad]',\n",
"    xscale='log',\n",
"    xticks=[1, 2, 3],\n",
"    xticklabels=[1, 2, 3]\n",
")\n",
"\n",
"# light curves at selected wavelengths, each normalized by its mean over phase;\n",
"# colors are passed explicitly so each curve matches its marker lines\n",
"pick_wls = [0.845, 1.68, 2.6] * u.um\n",
"for i, pick_wl in enumerate(pick_wls):\n",
"    pick_ind = np.argmin(np.abs(pick_wl - phoenix_contrast_cool.wavelength))\n",
"    color = f'C{i}'\n",
"    light_curve = plot_lc[:, pick_ind]\n",
"    ax_time.plot(phase, light_curve / np.nanmean(light_curve), label=f'{pick_wl}', color=color)\n",
"    ax_wave.axvline(pick_wl.value, ls='--', color=color)\n",
"    ax_left.axvline(pick_wl.value, color=color, ls='--', lw=1)\n",
"\n",
"# spectra at selected phases; colors continue after the wavelength picks\n",
"pick_phases = [2.5, 3.5, 5.6]\n",
"for i, pick_phase in enumerate(pick_phases, start=3):\n",
"    pick_phase_ind = np.argmin(np.abs(pick_phase - phase))\n",
"    color = f'C{i}'\n",
"    ax_wave.semilogx(wl_um, plot_lc[pick_phase_ind], label=f'{pick_phase} rad', color=color)\n",
"    ax_time.axvline(pick_phase, ls='--', color=color)\n",
"    ax_left.axhline(pick_phase, color=color, ls='--', lw=1)\n",
"\n",
"for axis, title in zip([ax_time, ax_wave], ['Wavelength:', 'Phase:']):\n",
"    axis.legend(title=title, alignment='left', framealpha=1, loc='lower right')\n",
"\n",
"# the backslash before 'lambda' is escaped to avoid an invalid-escape warning\n",
"ax_time.set(\n",
"    xlabel='Phase [rad]', ylabel='Flux (relative to\\nmean at $\\\\lambda$)'\n",
")\n",
"ax_wave.set(\n",
"    xlabel='Wavelength [µm]', ylabel='Flux (relative to\\nunspotted spectrum)',\n",
"    xticks=[1, 2, 3, 4],\n",
"    xticklabels=[1, 2, 3, 4]\n",
")\n",
"fig.tight_layout()\n",
"plt.savefig('sodium.png', bbox_inches='tight', dpi=200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad36de8a-e754-4951-b346-fdd76d023ca8",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(8, 5))\n",
"for spectrum in (phoenix_phot, phoenix_hot, phoenix_cool):\n",
"    binned = bin_spectrum(spectrum, bins=300, log=False, min=lam_min, max=lam_max)\n",
"    teff = int(spectrum.meta[\"PHXTEFF\"])\n",
"    ax.semilogy(binned.wavelength.to_value(u.um), binned.flux, label=f'T={teff} K')\n",
"ax.legend()\n",
"# match the wavelength range of the map in the previous figure\n",
"ax.set(\n",
"    xlim=ax_left.get_xlim(),\n",
"    xlabel='Wavelength [µm]',\n",
"    ylabel=f'Spectral radiance [{binned.flux.unit.to_string(\"latex\")}]',\n",
")\n",
"fig.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac37595c-510e-4cec-910c-20c2062b5695",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "09247ef0-9e83-44ef-aa35-94d211d33582",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment