Skip to content

Instantly share code, notes, and snippets.

@Cadair
Last active December 21, 2021 12:21
Show Gist options
  • Save Cadair/9b213b3f3e8d80df00d0c9fa22fe726c to your computer and use it in GitHub Desktop.
Save Cadair/9b213b3f3e8d80df00d0c9fa22fe726c to your computer and use it in GitHub Desktop.
A helper class for a common celestial transform
import astropy.modeling.models as m
import astropy.units as u
from astropy.modeling import Model, Parameter
class SimpleCelestialTransform(Model):
    """
    Two-input, two-output model mapping ``(x, y)`` to ``(lon, lat)``.

    Every evaluation builds a compound model from the current parameter
    values (see `_generate_transform`) and applies it to the inputs::

        shift | scale | rot | projection | skyrot

    Parameters
    ----------
    crpix
        Two-element value. Each element is negated and used as the offset
        of a ``Shift`` applied to the matching input.
    cdelt
        Two-element value. Each element is the factor of a ``Multiply``
        applied to the matching input.
    pci_j
        Matrix handed to ``AffineTransformation2D``. Defaults to the 2x2
        identity.
    crval
        Two-element value passed as the first two arguments of
        ``RotateNative2Celestial``.
    lon_pole
        Third argument of ``RotateNative2Celestial``. Defaults to 180
        (presumably degrees -- confirm against callers).
    projection : type, keyword-only
        A ``Pix2SkyProjection`` subclass (the class itself, not an
        instance). It is instantiated with no arguments. Defaults to
        ``Pix2Sky_TAN``.

    Raises
    ------
    TypeError
        If ``projection`` is not a ``Pix2SkyProjection`` subclass.
    """

    n_inputs = 2
    n_outputs = 2

    crpix = Parameter()
    cdelt = Parameter()
    pci_j = Parameter(default=[[1.0, 0.0], [0.0, 1.0]])
    crval = Parameter()
    lon_pole = Parameter(default=180)

    standard_broadcasting = False
    _separable = False
    _input_units_allow_dimensionless = True

    @property
    def input_units(self):
        """Return the unit expected for each input: pixels for ``x`` and ``y``."""
        return {"x": u.pix, "y": u.pix}

    def __init__(self, *args, projection=m.Pix2Sky_TAN, **kwargs):
        super().__init__(*args, **kwargs)
        self.inputs = ("x", "y")
        self.outputs = ("lon", "lat")

        # ``issubclass`` itself raises a generic TypeError when handed a
        # non-class (e.g. an already constructed projection instance), so
        # check for that first and always report the descriptive message.
        is_projection_class = (
            isinstance(projection, type)
            and issubclass(projection, m.Pix2SkyProjection)
        )
        if not is_projection_class:
            raise TypeError("The projection keyword should be a Pix2SkyProjection model class.")

        self.projection = projection()

    def _generate_transform(self, crpix, cdelt, pci_j, crval, lon_pole):
        """
        Build the compound model for one set of parameter values.

        Parameters
        ----------
        crpix, cdelt, crval
            Indexable values; elements ``[0]`` and ``[1]`` are used.
        pci_j
            Matrix for the affine step. If it has a ``unit`` attribute the
            zero translation is given the same unit.
        lon_pole
            Forwarded unchanged to ``RotateNative2Celestial``.

        Returns
        -------
        The compound model
        ``shift | scale | rot | self.projection | skyrot``.
        """
        # The affine step needs a translation; it is always zero here. Only
        # ``pci_j`` is inspected: when it carries a unit, the translation is
        # multiplied by that unit (presumably so matrix and translation are
        # consistently unitful -- confirm against AffineTransformation2D).
        translation = (0, 0)
        if hasattr(pci_j, "unit"):
            translation *= pci_j.unit

        shift = m.Shift(-crpix[0]) & m.Shift(-crpix[1])
        scale = m.Multiply(cdelt[0]) & m.Multiply(cdelt[1])
        rot = m.AffineTransformation2D(pci_j, translation=translation)
        skyrot = m.RotateNative2Celestial(crval[0], crval[1], lon_pole)

        return shift | scale | rot | self.projection | skyrot

    def evaluate(self, x, y, crpix, cdelt, pci_j, crval, lon_pole):
        """
        Apply the transform to ``x`` and ``y`` using the given parameter values.

        The first element along the leading axis of ``crpix``, ``cdelt``,
        ``pci_j`` and ``crval`` is used to build the transform.

        Returns
        -------
        The result of calling the generated compound model on ``(x, y)``.
        """
        # NOTE(review): ``lon_pole`` is forwarded without ``[0]`` indexing,
        # unlike the other parameters -- confirm this is intended.
        celestial = self._generate_transform(
            crpix[0],
            cdelt[0],
            pci_j[0],
            crval[0],
            lon_pole,
        )
        return celestial(x, y)

    @property
    def inverse(self):
        """
        Return the inverse of the compound model built from this instance's
        current parameters.
        """
        celestial = self._generate_transform(
            self.crpix,
            self.cdelt,
            self.pci_j,
            self.crval,
            self.lon_pole,
        )
        return celestial.inverse
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment