Created
March 28, 2024 08:00
-
-
Save Sunmish/9afa5fdde8b6754eb6e5cbc8660ae311 to your computer and use it in GitHub Desktop.
Basic WCS axes plotting.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /usr/bin/env python | |
import os | |
import numpy as np | |
from astropy.io import fits | |
from astropy.wcs import WCS | |
from astropy.visualization import ZScaleInterval, AsymmetricPercentileInterval, simple_norm | |
from astropy.wcs.utils import proj_plane_pixel_scales | |
from astropy.visualization.wcsaxes import SphericalCircle | |
from astropy.table import Table | |
from matplotlib import pyplot as plt | |
import matplotlib as mpl | |
from matplotlib.gridspec import GridSpec, SubplotSpec | |
from mpl_toolkits.axes_grid1.anchored_artists import (AnchoredEllipse, | |
AnchoredSizeBar) | |
from matplotlib.patches import Ellipse | |
from matplotlib import rc | |
from matplotlib.font_manager import FontProperties | |
from regions import Regions | |
import cmasher as cmr | |
# Named anchor positions mapped to matplotlib's integer location codes
# (1 = "upper right" ... 10 = "center"), used by ``show_scalebar`` when
# building an ``AnchoredSizeBar``.
LOCATIONS = {
    name: code
    for code, name in enumerate(
        (
            "upper right",
            "upper left",
            "lower left",
            "lower right",
            "right",
            "center left",
            "center right",
            "lower center",
            "upper center",
            "center",
        ),
        start=1,
    )
}
def lut_to_cmap(lut_file, divide_by_255=True):
    """Build a matplotlib ``ListedColormap`` from a three-column (R, G, B) table.

    Parameters
    ----------
    lut_file : str or path-like
        Whitespace-delimited text table readable by ``np.genfromtxt``; each
        row is one colour of the map.
    divide_by_255 : bool, optional
        When True (default) the table values are divided by 255 before use,
        for tables stored as 0-255 integers. When False they are passed to
        ``ListedColormap`` unchanged.

    Returns
    -------
    matplotlib.colors.ListedColormap
    """
    rgb_table = np.genfromtxt(lut_file)
    colours = rgb_table / 255.0 if divide_by_255 else rgb_table
    return mpl.colors.ListedColormap(colours)
# Optional "sls" colormap read from $DROPBOX/scripts/sls.txt (values already in
# the 0-1 range, hence divide_by_255=False). This is best-effort: if the
# environment variable is unset (KeyError) or the table cannot be opened
# (OSError, e.g. the file is missing), ``sls`` is None instead of the whole
# module failing to import.
try:
    sls = lut_to_cmap(os.path.join(os.environ["DROPBOX"], "scripts", "sls.txt"), False)
except (KeyError, OSError):
    sls = None
def show_bdsf_catalogue(catalogue, ax, color,
                        do_ellipses=False,
                        marker="o",
                        markersize=150):
    """Overlay the sources of a PyBDSF-style catalogue on a WCSAxes.

    Parameters
    ----------
    catalogue : str
        Table readable by ``astropy.table.Table.read`` with columns ``RA`` and
        ``DEC`` (plus ``Maj``, ``Min`` and ``PA`` when ``do_ellipses`` is
        True). All are assumed to be in degrees, with ``Maj``/``Min`` the full
        axis lengths and ``PA`` measured east of north, as in PyBDSF output --
        TODO confirm for catalogues from other source finders.
    ax : astropy.visualization.wcsaxes.WCSAxes
        Axes to draw on; positions are interpreted in the FK5 frame.
    color : matplotlib colour
        Marker / ellipse edge colour.
    do_ellipses : bool, optional
        Draw each source's fitted ellipse (unfilled) instead of a marker.
    marker : str, optional
        Marker style for ``ax.scatter`` when ``do_ellipses`` is False.
    markersize : float, optional
        Marker area passed as ``s`` to ``ax.scatter``.

    Returns
    -------
    ax
    """
    table = Table.read(catalogue)
    fk5 = ax.get_transform("fk5")

    if not do_ellipses:
        # A single scatter call draws the whole catalogue at once.
        ax.scatter(np.asarray(table["RA"]), np.asarray(table["DEC"]),
                   s=markersize,
                   marker=marker,
                   color=color,
                   transform=fk5,
                   )
        return ax

    from matplotlib.transforms import Affine2D

    for ra, dec, bmaj, bmin, pa in zip(table["RA"], table["DEC"],
                                       table["Maj"], table["Min"],
                                       table["PA"]):
        ra, dec = float(ra), float(dec)
        # The ellipse is built in local tangent-plane offsets (east, north) in
        # degrees. ``height`` is the major axis, so angle=0 points it north, and
        # ``angle=-PA`` turns it from north toward +RA (east), matching a PA
        # measured east of north. The x-scale by 1/cos(Dec) then converts the
        # on-sky east offset into an RA offset before moving to (ra, dec) and
        # applying the FK5 world transform.
        to_world = (Affine2D()
                    .scale(1.0 / np.cos(np.deg2rad(dec)), 1.0)
                    .translate(ra, dec)
                    + fk5)
        ellipse = Ellipse((0.0, 0.0),
                          width=float(bmin),
                          height=float(bmaj),
                          angle=-float(pa),
                          edgecolor=color,
                          facecolor="none",
                          transform=to_world,
                          )
        ax.add_patch(ellipse)
    return ax
def get_axes_from_gs(gs, fig, N):
    """Return figure-fraction rectangles for the first ``N`` cells of a GridSpec.

    Parameters
    ----------
    gs : matplotlib.gridspec.GridSpec
        Grid whose cells are enumerated in flat (row-major) index order.
    fig : matplotlib.figure.Figure
        Figure used to resolve the cell positions.
    N : int
        Number of cells to return, starting from index 0.

    Returns
    -------
    list of list of float
        One ``[left, bottom, width, height]`` rectangle per cell, the form
        accepted by ``plt.axes`` / ``Figure.add_axes``.
    """
    # Bbox.bounds is (x0, y0, x1 - x0, y1 - y0), i.e. exactly the rect form.
    return [list(SubplotSpec(gs, cell).get_position(figure=fig).bounds)
            for cell in range(N)]
def get_last2d(array):
    """Return the trailing 2-D plane of an N-D array.

    Every leading axis beyond the last two is indexed at 0, so e.g. an array
    of shape (1, 1, ny, nx) becomes its (ny, nx) plane. Arrays with two or
    fewer dimensions are returned unchanged (the very same object).

    Adapted from https://stackoverflow.com/a/27111239
    """
    extra_axes = array.ndim - 2
    if extra_axes > 0:
        # Indexing with fewer indices than dimensions keeps the trailing axes.
        return array[(0,) * extra_axes]
    return array
def auto_v(pmin, pmax, data):
    """Return padded (vmin, vmax) display limits from percentiles of ``data``.

    A mirror of aplpy's old auto_v function: the ``pmin`` and ``pmax``
    percentiles of ``data`` are found with ``AsymmetricPercentileInterval``
    and the range is then widened on both sides.

    NOTE: the padding is applied sequentially, so the upper limit is widened
    using the already-lowered lower limit (by 11% of the original span rather
    than 10%). This is kept as-is, presumably to reproduce aplpy exactly --
    confirm before "fixing" it.
    """
    low, high = AsymmetricPercentileInterval(pmin, pmax).get_limits(data)
    low = low - 0.1 * (high - low)
    high = high + 0.1 * (high - low)
    return low, high
def recenter(s1, wcs, centre, fov, figsize, axes):
    """Centre ``s1`` on ``centre`` with x-limits stretched to the axes' aspect.

    Superseded in ``make_axes`` by ``recenter2``; kept for callers that want
    the aspect-ratio-aware behaviour.

    Parameters
    ----------
    s1 : matplotlib Axes (WCSAxes)
        Axes whose x/y limits (in pixels) are set.
    wcs : astropy.wcs.WCS
        WCS used to convert world coordinates to pixels.
    centre : sequence of float
        (lon, lat) world coordinates of the new centre (degrees assumed).
    fov : sequence of float
        Only ``fov[0]`` is used, as the extent in both directions.
    figsize : sequence of float
        Figure (width, height), e.g. from ``fig.get_size_inches()``.
    axes : sequence of float
        Axes rectangle ``[left, bottom, width, height]`` in figure fractions.

    Returns
    -------
    s1
    """
    xpix, ypix = wcs.all_world2pix(centre[0], centre[1], 0)
    # Physical width/height ratio of the axes box; the x-range is multiplied
    # by this so the horizontal extent follows the box shape.
    rxy = (figsize[0]/figsize[1]) * axes[2]/axes[3]

    half = 0.5*fov[0]
    # NOTE(review): fov[0] is used for both directions and the RA offset is not
    # divided by cos(dec), so the horizontal extent is fov[0] degrees of RA,
    # not of sky -- confirm this is intended.
    ypix1 = wcs.wcs_world2pix(centre[0], centre[1]+half, 0)[1]
    ypix2 = wcs.wcs_world2pix(centre[0], centre[1]-half, 0)[1]
    y_range = abs(ypix1 - ypix2)
    xpix1 = wcs.wcs_world2pix(centre[0]-half, centre[1], 0)[0]
    xpix2 = wcs.wcs_world2pix(centre[0]+half, centre[1], 0)[0]
    x_range = abs(xpix1 - xpix2)

    s1.set_ylim([ypix-0.5*y_range, ypix+0.5*y_range])
    s1.set_xlim([xpix-0.5*x_range*rxy, xpix+0.5*x_range*rxy])
    return s1
def recenter2(s1, wcs, centre, fov):
    """Set the limits of ``s1`` to a ``fov[0]`` x ``fov[1]`` box about ``centre``.

    Port of the width/height branch of aplpy's ``FITSFigure.recenter``:
    https://aplpy.readthedocs.io/en/stable/_modules/aplpy/core.html#FITSFigure.recenter

    ``centre`` and ``fov`` are in the world units of ``wcs`` (degrees for a
    celestial WCS, assumed). The half-widths are converted to pixels with the
    projection-plane pixel scales of ``wcs``.

    Returns
    -------
    s1
    """
    x_centre, y_centre = wcs.all_world2pix(centre[0], centre[1], 0)
    scale_x, scale_y = proj_plane_pixel_scales(wcs)[:2]
    half_x = fov[0] / scale_x * 0.5
    half_y = fov[1] / scale_y * 0.5
    s1.set_xlim([x_centre - half_x, x_centre + half_x])
    s1.set_ylim([y_centre - half_y, y_centre + half_y])
    return s1
def make_axes(header, fig,
              ax=None,
              gs_ax=None,
              data=None,
              vsetting=None,
              psetting=None,
              scale=1000.,
              cmap="gray",
              centre=None,
              fov=None,
              do_axis_labels=True,
              fontlabels=14.,
              fontticks=14.,
              do_colorbar=True,
              colorbar_label=None,
              colorbar_thickness=0.0075,
              colorbar_pad=0.0005,
              colorbar_label_pad=0.,
              colorbar_orientation="vertical",
              colorbar_label_on_top=False,
              colorbar_label_on_top_alignment="right",
              do_beam=True,
              rotate_labels=False,
              norm=None,
              sans=True,
              tick_direction="in",
              tick_colour="black",
              tick_size=8,
              aspect="auto",
              ):
    """Create a WCSAxes for a FITS image, optionally with image, colourbar and beam.

    Parameters
    ----------
    header : astropy.io.fits.Header
        Header whose celestial WCS defines the axes projection. Must contain
        BMAJ, BMIN and BPA when ``do_beam`` is True.
    fig : matplotlib.figure.Figure
        Figure the axes (and colourbar axes) are added to.
    ax : sequence of float, optional
        Axes rectangle ``[left, bottom, width, height]`` in figure fractions.
    gs_ax : matplotlib.gridspec.SubplotSpec, optional
        Grid cell to place the axes in; used when ``ax`` is None.
    data : 2-D array, optional
        Image to display; it is multiplied by ``scale`` before plotting.
    vsetting : (vmin, vmax), optional
        Explicit colour limits, in already-scaled units.
    psetting : (pmin, pmax), optional
        Percentiles passed to ``auto_v`` when ``vsetting`` is None. If both
        are None, ZScale limits are used. Auto limits are multiplied by
        ``scale``.
    scale : float, optional
        Multiplicative factor applied to ``data`` (default 1000, presumably
        Jy -> mJy to match the default colourbar label).
    cmap : str or Colormap, optional
        ``"sls"`` selects the module-level ``sls`` colormap.
    centre, fov : optional
        If both are given, the view is recentred with ``recenter2``.
    do_axis_labels : bool
        Style ticks/tick labels and set RA/Dec axis labels.
    fontlabels, fontticks : float
        Font sizes for axis/colourbar labels and tick labels.
    do_colorbar : bool
        Add a colourbar axes next to (vertical) or above (horizontal) the image.
    colorbar_label : str, optional
        Defaults to a Stokes I / mJy PSF^-1 label.
    colorbar_thickness, colorbar_pad : float
        Colourbar size/gap in figure fractions, corrected for figure aspect.
    colorbar_label_pad : float
        Label padding for a vertical colourbar (horizontal uses 7).
    colorbar_orientation : {"vertical", "horizontal"}
    colorbar_label_on_top : bool
        Put the label as the colourbar axes title instead.
    colorbar_label_on_top_alignment : str
        Horizontal alignment of that title.
    do_beam : bool
        Draw the restoring beam from BMAJ/BMIN/BPA in the lower-left corner.
    rotate_labels : bool
        Rotate the declination tick labels vertically.
    norm : matplotlib.colors.Normalize, optional
        Overrides ``vsetting``/``psetting`` when given.
    sans : bool
        If True, update the *global* ``plt.rcParams`` to use DejaVu Sans
        mathtext with usetex off.
    tick_direction, tick_colour, tick_size :
        Major tick styling (minor ticks are half ``tick_size``).
    aspect : str or float
        Passed to ``set_aspect`` on the image axes.

    Returns
    -------
    (s1, fig, colorbar)
        The WCSAxes, the figure, and the colourbar (None if ``do_colorbar``
        is False).

    Raises
    ------
    RuntimeError
        If neither ``ax`` nor ``gs_ax`` is given.
    """
    if cmap == "sls":
        cmap = sls

    if sans:
        params = {'text.usetex': False, 'mathtext.fontset': "dejavusans"}
        plt.rcParams.update(params)

    wcs = WCS(header).celestial
    if ax is not None:
        # Add to the figure we were handed; plt.axes() would silently use the
        # *current* figure, which need not be ``fig``.
        s1 = fig.add_axes(ax, projection=wcs)
    elif gs_ax is not None:
        s1 = fig.add_subplot(gs_ax, projection=wcs)
    else:
        raise RuntimeError("Either `ax` or `gs_ax` should be specified.")

    figsize = fig.get_size_inches()

    if data is not None:
        if norm is None:
            if vsetting is not None:
                vmin, vmax = vsetting
            else:
                if psetting is None:
                    vmin, vmax = ZScaleInterval().get_limits(data)
                else:
                    vmin, vmax = auto_v(psetting[0], psetting[1], data)
                # Auto limits come from the unscaled data; bring them to the
                # displayed (scaled) units.
                vmin *= scale
                vmax *= scale
            norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        s1.imshow(data*scale,
                  norm=norm,
                  cmap=cmap,
                  origin="lower",
                  aspect="auto",
                  )

    if centre is not None and fov is not None:
        s1 = recenter2(s1, wcs, centre, fov)

    if do_axis_labels:
        ra = s1.coords[0]
        dec = s1.coords[1]
        for coord in (ra, dec):
            coord.set_ticks(
                size=tick_size,
                direction=tick_direction,
                color=tick_colour,
            )
            coord.display_minor_ticks(True)
            coord.set_minor_frequency(4)
            coord.set_ticklabel(
                color="black",
                size=fontticks,
                pad=5.,
                exclude_overlapping=True,
            )
            coord.tick_params(
                which="minor",
                length=tick_size*0.5,
            )
            coord.set_auto_axislabel(False)
        ra.set_axislabel(r"$\alpha_{\mathrm{\mathsf{J2000}}}$",
                         fontsize=fontlabels,
                         )
        dec.set_axislabel(r"$\delta_{\mathrm{\mathsf{J2000}}}$",
                          fontsize=fontlabels,
                          )
        if rotate_labels:
            dec.set_ticklabel(rotation="vertical")

    s1.set_aspect(aspect)

    if do_colorbar:
        left, bottom, width, height = s1.get_position().bounds
        # Pads/thicknesses are given in figure fractions; scale by the figure
        # aspect so they are the same physical size either way.
        fig_aspect = figsize[0]/figsize[1]
        if colorbar_orientation == "horizontal":
            cbax = [left, bottom+height+colorbar_pad*fig_aspect,
                    width, colorbar_thickness*fig_aspect]
            label_pad = 7
        else:
            cbax = [left+width+colorbar_pad/fig_aspect, bottom,
                    colorbar_thickness/fig_aspect, height]
            label_pad = colorbar_label_pad
        colorbar_axis = fig.add_axes(cbax)
        colorbar = mpl.colorbar.ColorbarBase(colorbar_axis,
                                             cmap=plt.get_cmap(cmap),
                                             norm=norm,
                                             orientation=colorbar_orientation,
                                             )
        if colorbar_label is None:
            colorbar_label = r"Stokes $I$ / mJy PSF$^{-1}$"
        if colorbar_label_on_top:
            colorbar_axis.set_title(colorbar_label, fontsize=fontlabels,
                                    ha=colorbar_label_on_top_alignment)
        else:
            colorbar.set_label(colorbar_label, fontsize=fontlabels,
                               labelpad=label_pad)
        if colorbar_orientation == "horizontal":
            colorbar.ax.xaxis.set_ticks_position("top")
            colorbar.ax.xaxis.set_label_position("top")
        else:
            colorbar.ax.yaxis.set_ticks_position("right")
            colorbar.ax.yaxis.set_label_position("right")
        colorbar.ax.tick_params(which="major",
                                labelsize=fontticks,
                                length=4.,
                                direction="out",
                                labelcolor="black",
                                )
    else:
        colorbar = None

    if do_beam:
        pix_scale = proj_plane_pixel_scales(wcs)
        sx, sy = pix_scale[0], pix_scale[1]
        bmaj = header["BMAJ"]
        bmin = header["BMIN"]
        bpa = header["BPA"]
        # Geometric-mean pixel size converts the beam axes to pixel units.
        xypixscale = np.sqrt(sx*sy)
        beam = AnchoredEllipse(s1.transData,
                               width=bmin / xypixscale,
                               height=bmaj / xypixscale,
                               angle=bpa,
                               loc="lower left",
                               frameon=True,
                               pad=0.2,
                               borderpad=1.,
                               )
        beam.ellipse.set_edgecolor("black")
        beam.ellipse.set_facecolor("black")
        s1.add_artist(beam)

    return s1, fig, colorbar
def show_contours(ax, contour_image, color, levels,
                  linewidth=1.2,
                  linestyle="-",
                  zorder=None):
    """Apply contours from a FITS image to an existing WCSAxes.

    The contour image's own celestial WCS is used as the plotting transform,
    so it does not need to share a pixel grid with ``ax``. Length-1 axes of
    the primary-HDU data are squeezed away before contouring.

    Parameters
    ----------
    ax : astropy.visualization.wcsaxes.WCSAxes
    contour_image : str
        Path to the FITS file; the primary HDU (index 0) is used.
    color : colour or sequence of colours
        Passed to ``contour`` as ``colors``.
    levels : sequence of float
        Contour levels in the image's data units.
    linewidth, linestyle, zorder : optional
        Passed to ``contour`` as ``linewidths``, ``linestyles``, ``zorder``.

    Returns
    -------
    ax
    """
    with fits.open(contour_image) as hdu_list:
        primary = hdu_list[0]
        image = np.squeeze(primary.data)
        image_wcs = WCS(primary.header).celestial
        ax.contour(image,
                   levels=levels,
                   colors=color,
                   linewidths=linewidth,
                   linestyles=linestyle,
                   transform=ax.get_transform(image_wcs),
                   zorder=zorder,
                   )
    return ax
def show_beam(ax, wcs, bmaj, bmin, bpa,
              loc="lower left",
              color="black"):
    """Draw a filled beam ellipse anchored in a corner of ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
    wcs : astropy.wcs.WCS
        Used for its projection-plane pixel scales.
    bmaj, bmin : float
        Beam major/minor axes in the same world units as the pixel scales of
        ``wcs`` (degrees for a celestial WCS, assumed).
    bpa : float
        Passed to ``AnchoredEllipse`` as ``angle``; the ellipse's height is
        the major axis.
    loc : str, optional
        Anchor location for ``AnchoredEllipse``.
    color : matplotlib colour, optional
        Used for both edge and face of the ellipse.

    Returns
    -------
    ax
    """
    scales = proj_plane_pixel_scales(wcs)
    # Geometric-mean pixel size converts the beam axes to pixel units.
    pixel_size = np.sqrt(scales[0]*scales[1])
    beam = AnchoredEllipse(
        ax.transData,
        width=bmin / pixel_size,
        height=bmaj / pixel_size,
        angle=bpa,
        loc=loc,
        frameon=True,
        pad=0.5,
        borderpad=1.,
        snap=False,
    )
    for paint in (beam.ellipse.set_edgecolor, beam.ellipse.set_facecolor):
        paint(color)
    ax.add_artist(beam)
    return ax
def show_scalebar(ax, wcs, scale, label, fontsize,
                  loc="upper right",
                  color="white",
                  frame=False,
                  borderpad=0.4,
                  pad=0.5,
                  size_vertical=0.,
                  **kwargs):
    """Show a linear scale bar on an axis.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes in pixel data coordinates (e.g. a WCSAxes).
    wcs : astropy.wcs.WCS
        Used for its projection-plane pixel scales.
    scale : float
        Bar length in the world units of the pixel scales (degrees for a
        celestial WCS, assumed).
    label : str
        Text drawn with the bar.
    fontsize : float
        Label font size.
    loc : str or int, optional
        A key of ``LOCATIONS`` (e.g. "upper right") or a raw matplotlib
        location code.
    color : matplotlib colour, optional
    frame : bool, optional
        Draw a frame around the bar.
    borderpad, pad : float, optional
        Passed to ``AnchoredSizeBar``.
    size_vertical : float, optional
        Bar thickness in data (pixel) units, passed to ``AnchoredSizeBar``.
    **kwargs
        Further ``AnchoredSizeBar`` keyword arguments.

    Returns
    -------
    (ax, scalebar)
    """
    pix_scale = proj_plane_pixel_scales(wcs)
    sx, sy = pix_scale[0], pix_scale[1]
    # Geometric-mean pixel size converts the world length to pixels.
    xypixscale = np.sqrt(sx*sy)
    length = scale / xypixscale
    loc_code = LOCATIONS[loc] if isinstance(loc, str) else loc
    scalebar = AnchoredSizeBar(
        ax.transData, length, label, loc_code,
        pad=pad,
        borderpad=borderpad,
        sep=5,
        frameon=frame,
        color=color,
        size_vertical=size_vertical,
        fontproperties=FontProperties(size=fontsize),
        **kwargs
    )
    ax.add_artist(scalebar)
    return ax, scalebar
def show_region(ax, wcs, region_file, color=None):
    """Plot every region from a DS9 region file on ``ax``.

    Sky regions are converted to pixel regions with ``wcs``. Regions that are
    already pixel regions (a DS9 file in image coordinates) have no
    ``to_pixel`` method and are plotted as they are.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes in pixel data coordinates.
    wcs : astropy.wcs.WCS
        Used to convert sky regions to pixels.
    region_file : str
        Path to a DS9-format region file.
    color : matplotlib colour, optional
        If None, each region's own/default colour is used.

    Returns
    -------
    ax
    """
    from regions import PixelRegion

    plot_kwargs = {"ax": ax, "zorder": 100}
    if color is not None:
        plot_kwargs["color"] = color

    for region in Regions.read(region_file, format="ds9"):
        if isinstance(region, PixelRegion):
            pixel_region = region
        else:
            pixel_region = region.to_pixel(wcs)
        pixel_region.plot(**plot_kwargs)
    return ax
def show_png(ax, png,
             vertical_flip=False,
             horizontal_flip=False,
             interpolation="nearest"):
    """Display a raster image file (e.g. a PNG) on ``ax`` with origin="lower".

    Parameters
    ----------
    ax : matplotlib Axes
    png : str or file-like
        Anything ``PIL.Image.open`` accepts.
    vertical_flip : bool, optional
        Flip the image top-to-bottom before display.
    horizontal_flip : bool, optional
        Flip the image left-to-right before display.
    interpolation : str, optional
        Passed to ``imshow``.

    Returns
    -------
    ax

    Raises
    ------
    ImportError
        If neither Pillow nor the legacy standalone PIL is installed.
    """
    try:
        from PIL import Image
    except ImportError:
        try:
            import Image  # legacy standalone PIL
        except ImportError:
            raise ImportError("The Python Imaging Library (PIL) is required to read in RGB images")
    else:
        # Disable Pillow's decompression-bomb size limit so large images load.
        Image.MAX_IMAGE_PIXELS = None

    # Runs for whichever import succeeded; the context manager closes the file.
    with Image.open(png) as image:
        if vertical_flip:
            image = image.transpose(Image.FLIP_TOP_BOTTOM)
        if horizontal_flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        # imshow converts the image to an array immediately, so it is safe
        # for the file to be closed afterwards.
        ax.imshow(image,
                  interpolation=interpolation,
                  origin="lower",
                  )
    return ax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment