Skip to content

Instantly share code, notes, and snippets.

@Cadair
Last active December 22, 2021 15:47
Show Gist options
  • Save Cadair/a82867a80fb743bfa7c4f3dd3a82d630 to your computer and use it in GitHub Desktop.
Save Cadair/a82867a80fb743bfa7c4f3dd3a82d630 to your computer and use it in GitHub Desktop.
import numpy as np
import astropy.units as u
from astropy.modeling import Model, CompoundModel
import astropy.modeling.models as m
from astropy.coordinates.matrix_utilities import rotation_matrix
from gwcs.utils import create_projection_transform
from models import *
# --- Earlier experiments, kept commented out for reference -----------------
# This first attempt built the varying celestial transform by hand from
# individual models (shift, scale, varying affine, TAN projection, varying
# native->celestial rotation). NOTE(review): it refers to names
# (``varying_matrix_lt``, ``time``) that are only defined further down.
# crpix1u, crpix2u = (1024, 1024) * u.pix
# crval1u, crval2u = (0, 0) * u.arcsec
# cdelt1u, cdelt2u = (1, 1) * u.arcsec / u.pix
# pcu = np.identity(2) * u.arcsec
# shiftu = m.Shift(-crpix1u) & m.Shift(-crpix2u)
# scaleu = m.Multiply(cdelt1u) & m.Multiply(cdelt2u)
# vrot = VaryingAffineTransformation2D()
# tanu = m.Pix2Sky_TAN()
# native = shiftu | scaleu
# skyrotu = VaryingRotateNative2Celestial(180 * u.deg)
# linear_pointing_shift1 = m.Linear1D(slope=10*u.arcsec / u.pix, intercept=crval1u)
# linear_pointing_shift2 = m.Linear1D(slope=20*u.arcsec / u.pix, intercept=crval2u)
# crval = m.Mapping((0, 0)) | (linear_pointing_shift1 & linear_pointing_shift2)
# rmlt = RotationMatrixLookup(lookup_table=varying_matrix_lt)
# forward = m.Mapping((0, 1, 2, 2, 2)) | ((((native & rmlt) | vrot | tanu) & crval) | skyrotu) & time
# Ten 2x2 matrices: the upper-left 2x2 of ``rotation_matrix`` for angles
# evenly spaced from 0 to 90 (degrees presumably, rotation_matrix's default
# unit -- TODO confirm). Used below as the per-z pc table.
varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 10)]
# Temporal model: time = 5 s/pix * z + 0 s.
time = m.Linear1D(slope=5*u.s/u.pix, intercept=0*u.s)
# --- More experiments (unitless and unitful variants), kept for reference ---
# sct = SimpleCelestialTransform(crpix=(0, 0),
# crval=(0, 0),
# cdelt=(1, 1),
# pc=np.identity(2))
# print(sct(0,0))
# sctu = SimpleCelestialTransform(crpix=(0, 0)*u.pix,
# crval=(0, 0)*u.arcsec,
# cdelt=(1, 1)*u.arcsec/u.pix,
# pc=np.identity(2)*u.arcsec,
# lon_pole=180*u.deg)
# print(sctu(0*u.pix,0*u.pix))
# print(sctu(0,0))
# d3 = sctu & time
# vct = VaryingCelestialTransform(crpix=(0, 0),
# cdelt=(1, 1),
# crval_table=np.array((np.linspace(0, 3, 10),
# np.linspace(-1, 2, 10))).T,
# pc_table=varying_matrix_lt,
# lon_pole=180)
# print(vct)
# world = vct(1, 1, 4)
# pixel = vct.inverse(*world, 4)
# print(world, pixel)
# Unitful celestial transform: a single, constant crval of (0, 0) arcsec and
# one pc matrix per z pixel, taken from ``varying_matrix_lt``.
vct = VaryingCelestialTransform(crpix=(0, 0)*u.pix,
                                cdelt=(1, 1)*u.arcsec/u.pix,
                                crval_table=(0, 0)*u.arcsec,
                                pc_table=varying_matrix_lt * u.arcsec,
                                lon_pole=180*u.deg)
print(vct)
# Round trip at z = 4 pix: pixel -> world -> pixel.
world = vct(1*u.pix, 1*u.pix, 4*u.pix)
pixel = vct.inverse(*world, 4*u.pix)
print(world, pixel)
print()
print()
print()
# We can verify the forward transform with this: the Mapping duplicates the
# z input so it feeds both the celestial model and the temporal model.
simple_forward = m.Mapping((0, 1, 2, 2)) | (vct & time)
world = simple_forward(1*u.pix, 1*u.pix, 4*u.pix)
print(world)
# Now we build an instance of the coupled transform. One input (z, the last
# input of ``vct`` and the only input of ``time``) is shared, so the compound
# model takes (x, y, z) and its inverse recovers z from time.
ci = CoupledCompoundModel("&", vct, time, shared_inputs=1)
print(ci)
print(ci.inverse)
world = ci(1*u.pix, 1*u.pix, 4*u.pix)
print(world)
pixel = ci.inverse(*world)
print(pixel)
import numpy as np
import astropy.units as u
import astropy.modeling.models as m
from astropy.modeling import Model, Parameter, CompoundModel
from astropy.modeling.rotations import _EulerRotation, _to_orig_unit, _to_radian
class VaryingSkyRotation(_EulerRotation, Model):
    """
    Base class for `VaryingRotateNative2Celestial`.

    Only ``lon_pole`` is a model parameter. The other rotation angles are
    supplied to `_evaluate` at call time, which lets subclasses take them
    as model inputs.
    """

    lon_pole = Parameter(default=0,
                         getter=_to_orig_unit,
                         setter=_to_radian,
                         description="Longitude of a pole")

    def __init__(self, lon_pole, **kwargs):
        super().__init__(lon_pole, **kwargs)
        self.axes_order = 'zxz'

    def _evaluate(self, phi, theta, lon, lat, lon_pole):
        """
        Apply the Euler rotation and shift negative longitudes up by 360.

        Parameters
        ----------
        phi, theta : float or array
            Spherical coordinates to rotate.
        lon, lat, lon_pole : float
            The three Euler angles, passed positionally to
            ``_EulerRotation.evaluate`` together with ``self.axes_order``.

        Returns
        -------
        alpha, delta : float or array
            The rotated coordinates, with negative longitudes shifted up by 360.
        """
        alpha, delta = super().evaluate(phi, theta, lon, lat, lon_pole,
                                        self.axes_order)
        mask = alpha < 0
        if isinstance(mask, np.ndarray):
            alpha[mask] += 360
        elif mask:
            # Scalar result: only shift when it is actually negative. The
            # previous unconditional ``alpha += 360`` moved every non-negative
            # scalar longitude out of range.
            alpha += 360
        return alpha, delta
class VaryingRotateNative2Celestial(VaryingSkyRotation):
    """
    Rotate native spherical coordinates to celestial coordinates.

    Unlike `~astropy.modeling.models.RotateNative2Celestial`, the celestial
    reference point (``lon``, ``lat``) is taken as model *inputs* rather than
    parameters, so it can vary between evaluations. Only ``lon_pole`` is a
    parameter.
    """

    n_inputs = 4
    n_outputs = 2

    @property
    def input_units(self):
        """ Input units. """
        return {self.inputs[0]: u.deg,
                self.inputs[1]: u.deg,
                self.inputs[2]: u.deg,
                self.inputs[3]: u.deg}

    @property
    def return_units(self):
        """ Output units. """
        return {self.outputs[0]: u.deg, self.outputs[1]: u.deg}

    def __init__(self, lon_pole, **kwargs):
        super().__init__(lon_pole, **kwargs)
        self.inputs = ('phi_N', 'theta_N', 'lon', 'lat')
        self.outputs = ('alpha_C', 'delta_C')

    @staticmethod
    def _degrees_to_radians(angle):
        """Convert an input angle (float assumed degrees, or Quantity) to radians."""
        if isinstance(angle, u.Quantity):
            return angle.to_value(u.rad)
        return np.deg2rad(angle)

    def evaluate(self, phi_N, theta_N, lon, lat, lon_pole):
        """
        Rotate native coordinates onto the celestial sphere.

        Parameters
        ----------
        phi_N, theta_N : float or `~astropy.units.Quantity` ['angle']
            Angles in the native coordinate system.
            If float, assumed in degrees.
        lon, lat : float or `~astropy.units.Quantity` ['angle']
            Celestial longitude and latitude of the reference point, supplied
            as model inputs. If float, assumed in degrees.
        lon_pole : float or `~astropy.units.Quantity` ['angle']
            Value of the ``lon_pole`` parameter. Its setter (``_to_radian``)
            stores it in radians, so it is used as radians here.

        Returns
        -------
        alpha_C, delta_C : float or `~astropy.units.Quantity` ['angle']
            Angles on the celestial sphere.
            If float, in degrees.
        """
        # ``lon`` and ``lat`` are inputs declared in degrees by
        # ``input_units``. They have not been through a radian setter, but
        # the Euler-angle formulas below work in radians.
        lon = self._degrees_to_radians(lon)
        lat = self._degrees_to_radians(lat)
        # lon_pole has already been through its setter, so its value is in
        # radians; only strip the unit.
        if isinstance(lon_pole, u.Quantity):
            lon_pole = lon_pole.value
        # Convert to Euler angles.
        # NOTE(review): the Euler rotation builds one rotation matrix per
        # call, so lon/lat presumably need to be size-1 here -- TODO confirm.
        phi = lon_pole - np.pi / 2
        theta = - (np.pi / 2 - lat)
        psi = -(np.pi / 2 + lon)
        alpha_C, delta_C = self._evaluate(phi_N, theta_N, phi, theta, psi)
        return alpha_C, delta_C
class RotationMatrixLookup(Model):
    """
    Lookup the appropriate rotation matrix for the given input.

    The input is rounded to the nearest integer and used as an index into
    ``lookup_table``.
    """

    n_inputs = 1
    n_outputs = 1

    lookup_table = Parameter(fixed=True,
                             description="A series of rotation matrices")

    @classmethod
    def evaluate(cls, x, lookup_table):
        # Both the input and the parameter arrive with an extra leading
        # dimension (observed by the original author; the cause is not
        # established here), so strip it before indexing.
        index = np.array(np.round(x), dtype=int)[0]
        return lookup_table[0][index]
class VaryingAffineTransformation2D(Model):
    """
    A 2D affine transformation whose matrix is passed in as a model input.

    The matrix is the third *input* (not a parameter), so it may differ from
    one evaluation to the next. For simplicity no translation can be set; a
    zero offset is always used.
    """

    n_inputs = 3
    n_outputs = 2
    standard_broadcasting = False
    _separable = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.inputs = ("x", "y", "matrix")
        self.outputs = ("x", "y")

    @property
    def input_units(self):
        # No input unit requirements are declared.
        return None

    @classmethod
    def evaluate(cls, x, y, matrix):
        """
        Transform 2D Cartesian coordinates by ``matrix``.

        Parameters
        ----------
        x, y : array, float
            The x and y coordinates (lists of points or a single pair).
        matrix : array
            The 2x2 matrix to apply; may carry a ``unit``.

        Returns
        -------
        The transformed x and y, as computed by
        ``AffineTransformation2D.evaluate`` with a zero translation.
        """
        # When the matrix carries a unit, attach that unit to the zero offset.
        zero_offset = [0, 0] * matrix.unit if hasattr(matrix, "unit") else [0, 0]
        return m.AffineTransformation2D.evaluate(x, y, matrix, zero_offset)
class SimpleCelestialTransform(Model):
    """
    A 2D pixel-to-celestial transform built from scalar WCS-style parameters.

    The generated pipeline is: shift by ``-crpix``, multiply by ``cdelt``,
    apply the ``pc`` matrix, apply ``projection`` (``Pix2Sky_TAN`` by
    default), then rotate native to celestial coordinates using ``crval``
    and ``lon_pole``.
    """

    n_inputs = 2
    n_outputs = 2

    crpix = Parameter()
    cdelt = Parameter()
    pc = Parameter(default=[[1.0, 0.0], [0.0, 1.0]])
    crval = Parameter()
    lon_pole = Parameter(default=180)

    standard_broadcasting = False
    _separable = False
    _input_units_allow_dimensionless = True

    @property
    def input_units(self):
        return {"x": u.pix, "y": u.pix}

    def __init__(self, *args, projection=None, **kwargs):
        super().__init__(*args, **kwargs)
        if type(self) is SimpleCelestialTransform:
            self.inputs = ("x", "y")
            self.outputs = ("lon", "lat")

        if projection is None:
            # Create a fresh projection per instance instead of sharing one
            # default model object evaluated once at import time.
            projection = m.Pix2Sky_TAN()
        if not isinstance(projection, m.Pix2SkyProjection):
            raise TypeError("The projection keyword should be a Pix2SkyProjection model class.")
        self.projection = projection

    def _generate_transform(self,
                            crpix,
                            cdelt,
                            pc,
                            crval,
                            lon_pole):
        """Build the compound celestial model from scalar parameter values."""
        # Make translation unitful if all parameters have units
        translation = (0, 0)
        if hasattr(pc, "unit") and pc.unit is not None:
            translation = translation * pc.unit
            # If we have units then we need to convert all things to Quantity
            # as they might be Parameter classes. This must stay inside the
            # branch: wrapping unitless values would produce dimensionless
            # Quantities that break the unitless path.
            crpix = u.Quantity(crpix)
            cdelt = u.Quantity(cdelt)
            crval = u.Quantity(crval)
            lon_pole = u.Quantity(lon_pole)
            pc = u.Quantity(pc)

        shift = m.Shift(-crpix[0]) & m.Shift(-crpix[1])
        scale = m.Multiply(cdelt[0]) & m.Multiply(cdelt[1])
        rot = m.AffineTransformation2D(pc, translation=translation)
        skyrot = m.RotateNative2Celestial(crval[0], crval[1], lon_pole)
        return shift | scale | rot | self.projection | skyrot

    def evaluate(self,
                 x,
                 y,
                 crpix,
                 cdelt,
                 pc,
                 crval,
                 lon_pole):
        # The parameters presumably arrive with an extra leading dimension,
        # hence the ``[0]`` indexing.
        # NOTE(review): lon_pole is not indexed like the others -- confirm.
        celestial = self._generate_transform(crpix[0],
                                             cdelt[0],
                                             pc[0],
                                             crval[0],
                                             lon_pole)
        return celestial(x, y)

    @property
    def inverse(self):
        """The inverse (celestial-to-pixel) of the generated transform."""
        celestial = self._generate_transform(self.crpix,
                                             self.cdelt,
                                             self.pc,
                                             self.crval,
                                             self.lon_pole)
        return celestial.inverse
class BaseVaryingCelestialTransform(Model):
    """
    Shared machinery for celestial transforms whose ``pc`` matrix and/or
    ``crval`` are looked up from tables indexed by a third pixel coordinate.
    """

    standard_broadcasting = False
    _separable = False
    _input_units_allow_dimensionless = True

    crpix = Parameter()
    cdelt = Parameter()
    lon_pole = Parameter(default=180)

    @staticmethod
    def _validate_table_shapes(pc_table, crval_table):
        """
        Check the pc and crval tables and return their shared table shape.

        A pc table of shape ``(2, 2)`` or a crval table of shape ``(2,)``
        is constant; otherwise its leading dimensions index the table.

        Returns
        -------
        tuple or None
            The leading (table) shape, or None if neither table varies.

        Raises
        ------
        ValueError
            If a table has the wrong trailing shape or the two table shapes
            disagree.
        """
        table_shape = None
        if pc_table.shape != (2, 2):
            if pc_table.shape[-2:] != (2, 2):
                raise ValueError("The pc table should be an array of 2x2 matrices.")
            table_shape = pc_table.shape[:-2]

        if crval_table.shape != (2,):
            # Slice rather than index so a 0-d table (e.g. crval_table=None)
            # raises this ValueError instead of an IndexError.
            if crval_table.shape[-1:] != (2,):
                raise ValueError("The crval table should be an array of coordinate "
                                 "pairs (the last dimension should have length 2).")

            if table_shape is not None:
                if table_shape != crval_table.shape[:-1]:
                    raise ValueError("The shape of the pc and crval tables should match. "
                                     f"The pc table has shape {table_shape} and the "
                                     f"crval table has shape {crval_table.shape[:-1]}")
            table_shape = crval_table.shape[:-1]

        return table_shape

    @staticmethod
    def get_pc_crval(z, pc, crval):
        """
        Select the pc matrix and crval pair for a single z pixel coordinate.

        Constant tables (pc of shape (2, 2), crval of shape (2,)) are
        returned unchanged.

        Raises
        ------
        TypeError
            If ``z`` does not contain exactly one value.
        IndexError
            If ``z`` rounds to a negative index into a varying table.
        """
        if isinstance(z, u.Quantity):
            z = z.value
        z = np.asarray(z)
        if z.size != 1:
            raise TypeError("Only a single z coordinate can be evaluated at a time, "
                            f"got an array of shape {z.shape}.")
        # Round to the nearest pixel rather than truncating: with int(),
        # floating-point noise from an inverse transform (e.g. 3.9999999)
        # would select the previous table entry.
        ind = int(np.round(z.item()))

        varying = pc.shape != (2, 2) or crval.shape != (2,)
        if varying and ind < 0:
            # Negative indices would silently wrap to the end of the table.
            raise IndexError(f"z index {ind} is out of bounds for the pc/crval tables.")

        if pc.shape != (2, 2):
            pc = pc[ind]
        if crval.shape != (2,):
            crval = crval[ind]

        return pc, crval

    def __init__(self, *args, crval_table=None, pc_table=None, projection=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.pc_table = np.asanyarray(pc_table)
        self.crval_table = np.asanyarray(crval_table)
        self.table_shape = self._validate_table_shapes(self.pc_table, self.crval_table)

        if projection is None:
            # Fresh instance per model, rather than one shared default object.
            projection = m.Pix2Sky_TAN()
        if not isinstance(projection, m.Pix2SkyProjection):
            raise TypeError("The projection keyword should be a Pix2SkyProjection model class.")
        self.projection = projection
class VaryingCelestialTransform(BaseVaryingCelestialTransform):
    """
    Pixel (x, y, z) to celestial (lon, lat) transform.

    The pc matrix and crval pair used for each evaluation are looked up from
    ``pc_table`` / ``crval_table`` using the ``z`` pixel coordinate.
    """

    n_inputs = 3
    n_outputs = 2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inputs = ("x", "y", "z")
        self.outputs = ("lon", "lat")

    @property
    def input_units(self):
        return {"x": u.pix, "y": u.pix, "z": u.pix}

    def _celestial_for(self, z, crpix, cdelt, lon_pole):
        """Build the 2D celestial model for one z index with the given scalars."""
        pc_at_z, crval_at_z = self.get_pc_crval(z, self.pc_table, self.crval_table)
        return SimpleCelestialTransform(crpix=crpix,
                                        cdelt=cdelt,
                                        pc=pc_at_z,
                                        crval=crval_at_z,
                                        lon_pole=lon_pole,
                                        projection=self.projection)

    def transform_at_index(self, z):
        """Return the 2D celestial model used at pixel coordinate ``z``."""
        return self._celestial_for(z, self.crpix, self.cdelt, self.lon_pole)

    def evaluate(self, x, y, z, crpix, cdelt, lon_pole):
        # Parameters carry a leading dimension here, hence the [0] indexing.
        celestial = self._celestial_for(z, crpix[0], cdelt[0], lon_pole[0])
        return celestial(x, y)

    @property
    def inverse(self):
        """The (lon, lat, z) -> (x, y) inverse, sharing this model's tables."""
        return InverseVaryingCelestialTransform(crpix=self.crpix,
                                                cdelt=self.cdelt,
                                                lon_pole=self.lon_pole,
                                                pc_table=self.pc_table,
                                                crval_table=self.crval_table,
                                                projection=self.projection)
class InverseVaryingCelestialTransform(BaseVaryingCelestialTransform):
    """
    Celestial (lon, lat) plus pixel z to pixel (x, y) transform.

    This is the inverse of `VaryingCelestialTransform`. The ``z`` input
    selects the pc matrix and crval pair from the tables.
    """

    n_inputs = 3
    n_outputs = 2

    @property
    def input_units(self):
        return {"lon": u.deg, "lat": u.deg, "z": u.pix}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inputs = ("lon", "lat", "z")
        self.outputs = ("x", "y")

    def evaluate(self, lon, lat, z, crpix, cdelt, lon_pole, **kwargs):
        pc, crval = self.get_pc_crval(z,
                                      self.pc_table,
                                      self.crval_table)
        # Parameters carry a leading dimension here, hence the [0] indexing.
        sct = SimpleCelestialTransform(crpix=crpix[0],
                                       cdelt=cdelt[0],
                                       pc=pc,
                                       crval=crval,
                                       lon_pole=lon_pole[0],
                                       projection=self.projection)
        return sct.inverse(lon, lat)

    @property
    def inverse(self):
        """The forward (x, y, z) -> (lon, lat) transform, sharing this model's tables."""
        return VaryingCelestialTransform(crpix=self.crpix,
                                         cdelt=self.cdelt,
                                         lon_pole=self.lon_pole,
                                         pc_table=self.pc_table,
                                         crval_table=self.crval_table,
                                         projection=self.projection)
class CoupledCompoundModel(CompoundModel):
    """
    This class takes two models which share one or more inputs on the forward
    transform, and where the left hand model's inverse is dependent on the
    output of the right hand model's inverse output.

    Parameters
    ----------
    op : `str`
        The operator to use, can only be ``'&'``.
    left : `astropy.modeling.Model`
        The left hand model, should have one or more inputs which are shared
        with the right hand model on the forward transform, and also rely on
        these inputs for the inverse transform.
    right : `astropy.modeling.Model`
        The right hand model, no special behaviour is required here.
    shared_inputs : `int`
        The number of inputs (counted from the end of the inputs to the left
        model and the start of the inputs to the right model) which are
        shared between the two models. Must be at least 1 and at most the
        number of inputs of either model.

    Examples
    --------
    Take the following example with a time dependent celestial transform
    (modelled as dependent upon the pixel coordinate for time rather than the
    world coordinate).

    The forward transform uses the "z" pixel dimension as input to both the
    Celestial and Temporal models, this leads to the following transform in the
    forward direction::

          x  y          z
          │  │          │
          │  │  ┌───────┤
          │  │  │       │
          ▼  ▼  ▼       ▼
        ┌─────────┐ ┌────────┐
        │Celestial│ │Temporal│
        └─┬───┬───┘ └───┬────┘
          │   │         │
          ▼   ▼         ▼
         lon lat       time

    This could trivially be reproduced using `~astropy.modeling.models.Mapping`.

    The complexity is in the reverse transform, where the inverse Celestial
    transform is also dependent upon the pixel coordinate z.
    This means that the output of the inverse Temporal transform has to be
    duplicated as an input to the Celestial transform's inverse.
    This is achieved by the use of the ``Mapping`` models in
    ``CoupledCompoundModel.inverse`` to create a multi-stage compound model
    which duplicates the output of the right hand side model::

         lon lat       time
          │   │         │
          │   │         ▼
          │   │    ┌─────────┐
          │   │    │Temporal'│
          │   │    └──┬──┬───┘
          │   │     z │  │
          │   │ ┌─────┘  │
          │   │ │        │
          ▼   ▼ ▼        │
        ┌──────────┐     │
        │Celestial'│     │
        └─┬───┬────┘     │
          │   │          │
          ▼   ▼          ▼
          x   y          z
    """

    _separable = False

    def __init__(self, op, left, right, name=None, shared_inputs=1):
        if op != "&":
            raise ValueError(
                f"The {self.__class__.__name__} class should only be used with the & operator."
            )
        # Validate up front: with shared_inputs=0, ``inputs[:-0]`` would drop
        # every input. The inverse also requires at least one shared input,
        # and no model can share more inputs than it has.
        max_shared = min(left.n_inputs, right.n_inputs)
        if not 1 <= shared_inputs <= max_shared:
            raise ValueError(
                f"shared_inputs must be between 1 and {max_shared}, got {shared_inputs}."
            )
        super().__init__(op, left, right, name=name)
        # The shared inputs appear once: drop the duplicates from the end.
        self.n_inputs = self.n_inputs - shared_inputs
        self.inputs = self.inputs[:self.n_inputs]
        self.shared_inputs = shared_inputs

    def _evaluate(self, *args, **kw):
        # The left model takes the leading inputs and the right model the
        # trailing ones; the shared inputs overlap and are passed to both.
        leftval = self.left(*(args[:self.left.n_inputs]), **kw)
        rightval = self.right(*(args[-self.right.n_inputs:]), **kw)

        return self._apply_operators_to_value_lists(leftval, rightval, **kw)

    @property
    def inverse(self):
        """
        Build the inverse as a two-stage compound model.

        Stage 1 passes the left model's outputs through unchanged while
        inverting the right model's outputs back to the right model's inputs.
        Stage 2 feeds the left outputs plus the shared inputs into the left
        inverse, and also passes all recovered right inputs out unchanged.
        """
        left_inverse = self.left.inverse
        right_inverse = self.right.inverse

        # Count left outputs directly: ``n_outputs - shared_inputs`` is only
        # correct when the right model has exactly ``shared_inputs`` outputs.
        n_left_outputs = self.left.n_outputs
        n_right_inputs = right_inverse.n_outputs

        # Pass through arguments to the left model unchanged while computing the right output
        passthrough = tuple(range(n_left_outputs))
        step1 = m.Mapping(passthrough) & right_inverse

        # After step1 the right model's recovered inputs follow the left
        # outputs. The shared ones are the first ``shared_inputs`` of them,
        # and only those belong to the left inverse.
        right_recovered = tuple(range(n_left_outputs, n_left_outputs + n_right_inputs))
        shared = right_recovered[:self.shared_inputs]
        inter_mapping = passthrough + shared + right_recovered
        step2 = m.Mapping(inter_mapping) | (left_inverse & m.Mapping(tuple(range(n_right_inputs))))

        return step1 | step2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment