Last active
January 16, 2021 21:08
-
-
Save Miladiouss/30c6a90da3243eafe19fb02acf0747b1 to your computer and use it in GitHub Desktop.
Easy astronomy FITS file handling for Python. It includes an easy cutout-and-save function. Additionally, a percentile normalization method is provided, which is ideal for scaling FITS images for better visualization (similar to MinMax of DS9).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from pathlib import Path | |
# AstroPy | |
import astropy | |
from astropy.coordinates import SkyCoord, GCRS, ICRS, GeocentricTrueEcliptic, Galactic | |
import astropy.units as u | |
from astropy.io import fits | |
from astropy.wcs import WCS | |
from astropy.nddata import Cutout2D | |
def sum2mag(pix_sum, zero_point=27.00):
    """Convert a summed pixel flux into a magnitude.

    Computes ``zero_point - 2.5 * log10(pix_sum)``, element-wise when
    ``pix_sum`` is a numpy array.

    Args:
        pix_sum: summed flux (scalar or array). It should be positive; numpy
            returns ``inf``/``nan`` (with a RuntimeWarning) otherwise.
        zero_point: photometric zero point added to the instrumental magnitude.

    Returns:
        The magnitude(s), as a numpy scalar or array.
    """
    flux_term = 2.5 * np.log10(pix_sum)
    return zero_point - flux_term
class FITS():
    """Wrapper around one image extension of a FITS file.

    The file is opened with ``fits.open`` and kept open until :meth:`close`
    is called (or a ``with FITS(...) as f:`` block exits).

    Attributes:
        path: the path given to the constructor.
        band: optional label, stored as given (not used by the methods here).
        file: the open ``HDUList``.
        ext_idx: index of the HDU whose data and WCS are used.
        data: ``file[ext_idx].data``.
        wcs: ``WCS`` built from ``file[ext_idx].header``.
        dim: ``data.shape``.
    """

    # Lower-cased suffixes that make ``cutout`` write a FITS file.
    _FITS_SUFFIXES = ('.fits', '.fit', '.fts')
    # Lower-cased suffixes that make ``cutout`` write a raster image.
    _IMAGE_SUFFIXES = ('.jpeg', '.jpg', '.png')

    def __init__(self, fits_path, band=None, ext_idx=1):
        self.path = fits_path
        self.band = band
        self.file = fits.open(fits_path)
        self.ext_idx = ext_idx
        try:
            hdu = self.file[self.ext_idx]
            self.data = hdu.data
            self.wcs = WCS(hdu.header)
            self.dim = self.data.shape
        except Exception:
            # Don't leak the open file handle if the extension is missing or
            # has no usable data / WCS.
            self.file.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the file; never suppress the exception.
        self.close()
        return False

    def cutout(self, pos=(50, 50), size=(64, 64), save_to_path=None,
               scale_func=None, dtype=None, overwrite=True):
        """Crop, rescale and cast the image; optionally save the result.

        Args:
            pos: cutout centre, either in pixels (e.g. ``(50, 50)``) or as a
                sky position (e.g. ``SkyCoord(+1.2, -3.4, unit="deg",
                frame="fk5")``); passed to ``Cutout2D`` as ``position``.
            size: cutout size passed to ``Cutout2D``. If falsy (``None``,
                ``()``, ``0``) the full image is used without cropping.
            save_to_path: optional output path (``str`` or ``Path``). A FITS
                suffix (``.fits``/``.fit``/``.fts``, any case) writes a copy of
                the whole file in which extension ``ext_idx`` is replaced by
                the cutout and its WCS. ``.jpeg``/``.jpg``/``.png`` are written
                with ``imsave``.
            scale_func: optional callable applied to the cropped array (e.g. a
                ``percentile_normalization`` wrapper); identity when None.
            dtype: output dtype such as ``'float32'`` or ``'uint8'``; defaults
                to the dtype of ``self.data``.
            overwrite: forwarded to ``HDUList.writeto`` for FITS output.

        Returns:
            The processed numpy array (always a new array, never ``self.data``).

        Raises:
            ValueError: if ``save_to_path`` has an unsupported suffix. This is
                checked before any cropping or scaling is done.
        """
        save_path = Path(save_to_path) if save_to_path else None
        suffix = None
        if save_path is not None:
            suffix = save_path.suffix.lower()
            if suffix not in self._FITS_SUFFIXES + self._IMAGE_SUFFIXES:
                raise ValueError(
                    'The file extension must be one of the following: '
                    '.fits, .fit, .fts, .jpeg, .jpg, .png. '
                    '"{}" was given instead.'.format(save_path.suffix))

        # Crop, or use the full frame when no size is requested.
        if size:
            cut = Cutout2D(data=self.data, position=pos, size=size, wcs=self.wcs)
            output, output_wcs = cut.data, cut.wcs
        else:
            output, output_wcs = self.data, self.wcs

        if scale_func is not None:
            output = scale_func(output)

        # astype copies, so the returned array never aliases self.data.
        output = output.astype(np.dtype(self.data.dtype if dtype is None else dtype))

        if save_path is not None:
            if suffix in self._FITS_SUFFIXES:
                self._write_fits(save_path, output, output_wcs, overwrite)
            else:
                # NOTE(review): `imsave` is not imported in the visible code
                # (presumably matplotlib.pyplot.imsave) - confirm; otherwise
                # this branch raises NameError.
                imsave(save_to_path, output)
        return output

    def _write_fits(self, save_path, data, wcs, overwrite):
        """Write a copy of ``self.file`` with HDU ``ext_idx`` replaced by ``data``.

        A new HDU of the same class, with a copied header, is built, so the
        open ``self.file`` is left untouched. The other HDUs are written as-is.
        """
        source_hdu = self.file[self.ext_idx]
        header = source_hdu.header.copy()
        if wcs is not None:
            header.update(wcs.to_header())
        new_hdu = type(source_hdu)(data=data, header=header)
        hdus = list(self.file)
        hdus[self.ext_idx] = new_hdu
        fits.HDUList(hdus).writeto(save_path, overwrite=overwrite)

    def close(self):
        """Close the underlying FITS file."""
        self.file.close()

    def reduce2uint8(self, output_path, p_low, p_high, overwrite=True):
        """Convert the image to a uint8 ("png-like") FITS file at ``output_path``.

        Pixel values are passed to ``percentile_normalization`` with
        ``p_low``/``p_high`` as the fed clipping bounds and ``scale_coef=255``,
        then cast to uint8. The output is a single primary HDU whose header is
        a copy of extension ``ext_idx``'s header, updated with ``self.wcs`` and
        the scaling parameters as the keywords ``P_LOW``, ``P_HIGH`` and
        ``SCALCOEF``.

        Args:
            output_path: destination file.
            p_low: lower clipping bound, in pixel values.
            p_high: upper clipping bound, in pixel values.
            overwrite: forwarded to ``HDUList.writeto``.
        """
        scale_coef = 255
        output_data = percentile_normalization(
            self.data, p_low_feed=p_low, p_high_feed=p_high, scale_coef=scale_coef
        ).astype(np.dtype('uint8'))

        # Header with the correct WCS.
        header = self.file[self.ext_idx].header.copy()
        header.update(self.wcs.to_header())
        # Item assignment creates real FITS cards. Assigning attributes on the
        # Header object would never reach the written file.
        header['P_LOW'] = (p_low, 'low clip value')
        header['P_HIGH'] = (p_high, 'high clip value')
        header['SCALCOEF'] = (scale_coef, 'scale factor after clipping')

        hdu = fits.PrimaryHDU(data=output_data, header=header)
        fits.HDUList([hdu]).writeto(output_path, overwrite=overwrite)
def percentile_normalization(data, percentile_low=1.5, percentile_high=1.5,
                             p_low_feed=None, p_high_feed=None, scale_coef=1):
    """Clip ``data`` between two bounds and rescale it to ``[0, scale_coef]``.

    By default the lower bound is the ``percentile_low``-th percentile and the
    upper bound is the ``(100 - percentile_high)``-th percentile of ``data``,
    with NaNs ignored. ``percentile_high`` is therefore measured from the top.
    Either bound can be given directly as a pixel value through ``p_low_feed``
    / ``p_high_feed``. Only ``None`` means "use the percentile", so ``0`` is a
    valid fed bound.

    After clipping, the values are shifted so their minimum is 0, divided by
    their maximum, and multiplied by ``scale_coef``. NOTE: the shift and the
    division use the clipped data's own min/max. A fed bound that the data
    never reaches therefore does not pin the scale.

    Args:
        data: array-like of pixel values.
        percentile_low: percentile (0-100) used for the lower bound.
        percentile_high: distance (0-100) from the 100th percentile used for
            the upper bound.
        p_low_feed: explicit lower bound (pixel value), overriding the
            percentile when not None.
        p_high_feed: explicit upper bound (pixel value), overriding the
            percentile when not None.
        scale_coef: value the maximum is mapped to (e.g. 255 for uint8).

    Returns:
        A new floating-point array with the shape of ``data``; ``data`` itself
        is not modified. NaN pixels stay NaN. If all finite clipped values are
        equal, the result is zeros rather than NaN from 0/0.
    """
    values = np.asarray(data)
    # Integer/bool input combined with integer bounds would keep an integer
    # dtype through np.clip, and the in-place division below would then fail.
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(np.float64)

    # Only compute the percentiles that are actually needed; nan-aware so a
    # single NaN pixel does not turn every bound (and output) into NaN.
    if p_low_feed is None:
        p_low = np.nanpercentile(values, percentile_low)
    else:
        p_low = p_low_feed
    if p_high_feed is None:
        p_high = np.nanpercentile(values, 100 - percentile_high)
    else:
        p_high = p_high_feed

    # Bound values between p_low and p_high.
    normalized = np.clip(values, p_low, p_high)
    # Shift the zero to prevent negative values.
    normalized = normalized - np.nanmin(normalized)
    # Normalize so the max is 1, unless the clipped range is degenerate.
    span = np.nanmax(normalized)
    if span > 0:
        normalized /= span
    normalized *= scale_coef
    return normalized
# ================================= Example 1 =================================
# Read a file and extract a rescaled cutout (e.g. for visualization)
# x = FITS('path/to/file.fits')
# sf = lambda data: percentile_normalization(data, percentile_high=1., percentile_low=30)
# d = x.cutout(pos=(92, 110), scale_func=sf)
# x.close()
# ================================= Example 2 =================================
# # Example of reducing an HSC-PDR2 float32 FITS file to uint8
# from pathlib import Path
# from FITS_Handler import FITS, percentile_normalization
# # Define the low and high clipping bounds (pixel values, e.g. precomputed
# # percentiles of the survey data) for each filter
# i_low = -0.481
# i_high = 1.900
# r_low = -0.280
# r_high = 1.613
# g_low = -0.175
# g_high = 0.761
# # Define input and output paths
# inPath = Path('/HSC-Drive/HSC-PDR2/tracts/8279/calexp-HSC-R-8279-5,4.fits')
# outPath = Path('uint8_' + inPath.name)
# # Read, reduce, and close
# r = FITS(inPath)
# r.reduce2uint8(outPath, p_low=r_low, p_high=r_high)
# r.close()
# # Print sizes
# print("Original Size: {:3.0f} MB".format(inPath.stat().st_size / 1e6))
# print("Output Size: {:3.0f} MB".format(outPath.stat().st_size / 1e6))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment