Last active
December 22, 2021 15:47
-
-
Save Cadair/a82867a80fb743bfa7c4f3dd3a82d630 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import astropy.units as u | |
| from astropy.modeling import Model, CompoundModel | |
| import astropy.modeling.models as m | |
| from astropy.coordinates.matrix_utilities import rotation_matrix | |
| from gwcs.utils import create_projection_transform | |
| from models import * | |
# Demonstration of the varying celestial transforms from ``models``.
# It builds a celestial model whose pc matrix changes along the third pixel
# axis and pairs it with a linear time axis. Pixel <-> world coordinates are
# then round-tripped, both through the celestial model alone and through a
# CoupledCompoundModel.

# 2x2 rotation matrices for ten angles evenly spaced from 0 to 90 degrees.
varying_matrix_lt = [
    rotation_matrix(angle)[:2, :2] for angle in np.linspace(0, 90, 10)
]

# Time grows linearly (5 s per pixel) with the third pixel coordinate.
time = m.Linear1D(slope=5 * u.s / u.pix, intercept=0 * u.s)

vct = VaryingCelestialTransform(
    crpix=(0, 0) * u.pix,
    cdelt=(1, 1) * u.arcsec / u.pix,
    crval_table=(0, 0) * u.arcsec,
    pc_table=varying_matrix_lt * u.arcsec,
    lon_pole=180 * u.deg,
)
print(vct)

# Celestial model on its own: the inverse needs z passed back in explicitly.
world = vct(1 * u.pix, 1 * u.pix, 4 * u.pix)
pixel = vct.inverse(*world, 4 * u.pix)
print(world, pixel)
print("\n\n")  # three newlines as a visual separator

# Reference forward result: a Mapping duplicates z so it feeds both models.
simple_forward = m.Mapping((0, 1, 2, 2)) | (vct & time)
world = simple_forward(1 * u.pix, 1 * u.pix, 4 * u.pix)
print(world)

# The coupled model shares z between the celestial and temporal parts, so its
# inverse only needs (lon, lat, time).
ci = CoupledCompoundModel("&", vct, time, shared_inputs=1)
print(ci)
print(ci.inverse)
world = ci(1 * u.pix, 1 * u.pix, 4 * u.pix)
print(world)
pixel = ci.inverse(*world)
print(pixel)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import astropy.units as u | |
| import astropy.modeling.models as m | |
| from astropy.modeling import Model, Parameter, CompoundModel | |
| from astropy.modeling.rotations import _EulerRotation, _to_orig_unit, _to_radian | |
class VaryingSkyRotation(_EulerRotation, Model):
    """
    Base class for sky rotations whose pole longitude is a model parameter.

    The remaining Euler angles are supplied at evaluation time, and the
    rotation uses the ``'zxz'`` axes order.
    """

    lon_pole = Parameter(default=0,
                         getter=_to_orig_unit,
                         setter=_to_radian,
                         description="Longitude of a pole")

    def __init__(self, lon_pole, **kwargs):
        super().__init__(lon_pole, **kwargs)
        self.axes_order = 'zxz'

    def _evaluate(self, phi, theta, lon, lat, lon_pole):
        """
        Apply the Euler rotation and shift negative output longitudes up by 360.

        Parameters
        ----------
        phi, theta : float or array
            Input spherical coordinates, forwarded to ``_EulerRotation.evaluate``.
        lon, lat, lon_pole : float or array
            The three Euler angles, forwarded to ``_EulerRotation.evaluate``.

        Returns
        -------
        alpha, delta : float or array
            Rotated coordinates; ``alpha`` has 360 added wherever it was negative.
        """
        alpha, delta = super().evaluate(phi, theta, lon, lat, lon_pole,
                                        self.axes_order)
        mask = alpha < 0
        if np.ndim(mask) > 0:
            alpha[mask] += 360
        elif mask:
            # Scalar / 0-d result: comparing it yields a numpy bool rather than
            # an ndarray, so only wrap when the longitude is actually negative.
            alpha += 360
        return alpha, delta
class VaryingRotateNative2Celestial(VaryingSkyRotation):
    """
    Rotate from Native to Celestial spherical coordinates.

    Unlike ``RotateNative2Celestial``, the celestial reference position is not
    fixed at construction time. It is passed as the ``lon`` and ``lat`` inputs,
    and only ``lon_pole`` is a parameter.

    Inputs are ``(phi_N, theta_N, lon, lat)`` and outputs are
    ``(alpha_C, delta_C)``, all in degrees.
    """

    n_inputs = 4
    n_outputs = 2

    @property
    def input_units(self):
        """ Input units. """
        return {self.inputs[0]: u.deg,
                self.inputs[1]: u.deg,
                self.inputs[2]: u.deg,
                self.inputs[3]: u.deg}

    @property
    def return_units(self):
        """ Output units. """
        return {self.outputs[0]: u.deg, self.outputs[1]: u.deg}

    def __init__(self, lon_pole, **kwargs):
        super().__init__(lon_pole, **kwargs)
        self.inputs = ('phi_N', 'theta_N', 'lon', 'lat')
        self.outputs = ('alpha_C', 'delta_C')

    def evaluate(self, phi_N, theta_N, lon, lat, lon_pole):
        """
        Parameters
        ----------
        phi_N, theta_N : float or `~astropy.units.Quantity` ['angle']
            Angles in the Native coordinate system.
            If float, assumed in degrees.
        lon, lat : float or `~astropy.units.Quantity` ['angle']
            Celestial longitude and latitude of the reference point, supplied
            as model inputs (declared in degrees by ``input_units``).
            If float, assumed in degrees.
        lon_pole : float or `~astropy.units.Quantity` ['angle']
            Value of the ``lon_pole`` parameter. Its ``_to_radian`` setter stores
            it in radians, so a float is taken to be in radians.

        Returns
        -------
        alpha_C, delta_C : float or `~astropy.units.Quantity` ['angle']
            Angles on the Celestial sphere.
            If float, in degrees.
        """
        # lon and lat are inputs, not setter-converted parameters, so they are
        # in degrees (or carry their own angular unit) and must be converted to
        # radians here. Only lon_pole has already been through _to_radian.
        lon = u.Quantity(lon, u.deg).to_value(u.rad)
        lat = u.Quantity(lat, u.deg).to_value(u.rad)
        lon_pole = u.Quantity(lon_pole, u.rad).to_value(u.rad)

        # Convert to Euler angles
        phi = lon_pole - np.pi / 2
        theta = - (np.pi / 2 - lat)
        psi = -(np.pi / 2 + lon)
        alpha_C, delta_C = self._evaluate(phi_N, theta_N, phi, theta, psi)
        return alpha_C, delta_C
class RotationMatrixLookup(Model):
    """
    Lookup the appropriate rotation matrix for the given input.

    The input is rounded to the nearest integer and used as an index into
    ``lookup_table``.
    """

    n_inputs = 1
    n_outputs = 1

    lookup_table = Parameter(fixed=True,
                             description="A series of rotation matrices")

    @classmethod
    def evaluate(cls, x, lookup_table):
        # Both the input and the parameter value arrive with an extra leading
        # axis (origin not pinned down by the original author); [0] drops it.
        index = np.array(np.round(x), dtype=int)[0]
        return lookup_table[0][index]
class VaryingAffineTransformation2D(Model):
    """
    Perform an affine transformation in 2 dimensions with the matrix as an input.

    No translation can be configured; a zero translation is always used, and
    it carries the matrix's unit when the matrix has one.
    """

    n_inputs = 3
    n_outputs = 2
    standard_broadcasting = False
    _separable = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.inputs = ("x", "y", "matrix")
        self.outputs = ("x", "y")

    @property
    def input_units(self):
        return None

    @classmethod
    def evaluate(cls, x, y, matrix):
        """
        Transform 2D Cartesian coordinates with ``matrix``.

        Parameters
        ----------
        x, y : array, float
            x and y coordinates, either as two sequences or as one pair.
        matrix : array or `~astropy.units.Quantity`
            The 2x2 transformation matrix.

        Returns
        -------
        x, y : array, float
            The transformed coordinates.
        """
        zero_offset = [0, 0] * matrix.unit if hasattr(matrix, "unit") else [0, 0]
        return m.AffineTransformation2D.evaluate(x, y, matrix, zero_offset)
class SimpleCelestialTransform(Model):
    """
    Pixel-to-celestial transform built from FITS-WCS style parameters.

    The generated pipeline is::

        Shift(-crpix) | Multiply(cdelt) | AffineTransformation2D(pc)
            | projection | RotateNative2Celestial(crval, lon_pole)

    Parameters
    ----------
    projection : `~astropy.modeling.models.Pix2SkyProjection`, optional
        Projection to use. Defaults to a new ``Pix2Sky_TAN()`` per instance.
    """

    n_inputs = 2
    n_outputs = 2

    crpix = Parameter()
    cdelt = Parameter()
    pc = Parameter(default=[[1.0, 0.0], [0.0, 1.0]])
    crval = Parameter()
    lon_pole = Parameter(default=180)

    standard_broadcasting = False
    _separable = False
    _input_units_allow_dimensionless = True

    @property
    def input_units(self):
        return {"x": u.pix, "y": u.pix}

    def __init__(self, *args, projection=None, **kwargs):
        super().__init__(*args, **kwargs)
        if type(self) is SimpleCelestialTransform:
            self.inputs = ("x", "y")
            self.outputs = ("lon", "lat")
        if projection is None:
            # Create a fresh projection per model instead of sharing one
            # default instance (evaluated once at definition time) across all.
            projection = m.Pix2Sky_TAN()
        if not isinstance(projection, m.Pix2SkyProjection):
            raise TypeError("The projection keyword should be a Pix2SkyProjection model class.")
        self.projection = projection

    def _generate_transform(self,
                            crpix,
                            cdelt,
                            pc,
                            crval,
                            lon_pole):
        """Build the compound pixel -> celestial model for the given values."""
        # Make translation unitful if all parameters have units
        translation = (0, 0)
        if hasattr(pc, "unit") and pc.unit is not None:
            translation *= pc.unit
            # If we have units then we need to convert all things to Quantity
            # as they might be Parameter classes
            crpix = u.Quantity(crpix)
            cdelt = u.Quantity(cdelt)
            crval = u.Quantity(crval)
            lon_pole = u.Quantity(lon_pole)
            pc = u.Quantity(pc)

        shift = m.Shift(-crpix[0]) & m.Shift(-crpix[1])
        scale = m.Multiply(cdelt[0]) & m.Multiply(cdelt[1])
        rot = m.AffineTransformation2D(pc, translation=translation)
        skyrot = m.RotateNative2Celestial(crval[0], crval[1], lon_pole)
        return shift | scale | rot | self.projection | skyrot

    def evaluate(self,
                 x,
                 y,
                 crpix,
                 cdelt,
                 pc,
                 crval,
                 lon_pole):
        # Every parameter value arrives with an extra leading axis, which [0]
        # removes. lon_pole needs the same treatment as the others, otherwise
        # the sky rotation gets a shape-(1,) pole and the outputs an extra axis.
        celestial = self._generate_transform(crpix[0],
                                             cdelt[0],
                                             pc[0],
                                             crval[0],
                                             lon_pole[0])
        return celestial(x, y)

    @property
    def inverse(self):
        celestial = self._generate_transform(self.crpix,
                                             self.cdelt,
                                             self.pc,
                                             self.crval,
                                             self.lon_pole)
        return celestial.inverse
class BaseVaryingCelestialTransform(Model):
    """
    Shared machinery for celestial transforms whose ``pc`` matrix and/or
    ``crval`` are looked up from tables by an extra pixel coordinate.

    Parameters
    ----------
    crval_table : array-like
        Either a single ``(2,)`` coordinate pair or an array of pairs whose
        last axis has length 2.
    pc_table : array-like
        Either a single ``(2, 2)`` matrix or an array of 2x2 matrices.
    projection : `~astropy.modeling.models.Pix2SkyProjection`, optional
        Projection to use. Defaults to a new ``Pix2Sky_TAN()`` per instance.
    """

    standard_broadcasting = False
    _separable = False
    _input_units_allow_dimensionless = True

    crpix = Parameter()
    cdelt = Parameter()
    lon_pole = Parameter(default=180)

    @staticmethod
    def _validate_table_shapes(pc_table, crval_table):
        """
        Check the table shapes and return the shape of the varying axes.

        Returns ``None`` when both tables are static.

        Raises
        ------
        ValueError
            If either table has the wrong trailing shape, or the varying axes
            of the two tables do not match.
        """
        table_shape = None
        if pc_table.shape != (2, 2):
            if pc_table.shape[-2:] != (2, 2):
                raise ValueError("The pc table should be an array of 2x2 matrices.")
            table_shape = pc_table.shape[:-2]

        if crval_table.shape != (2,):
            # A 0-d array (e.g. from crval_table=None) has no last axis; check
            # ndim first so it is reported as a ValueError, not an IndexError.
            if crval_table.ndim == 0 or crval_table.shape[-1] != 2:
                raise ValueError("The crval table should be an array of coordinate "
                                 "pairs (the last dimension should have length 2).")

            if table_shape is not None and table_shape != crval_table.shape[:-1]:
                raise ValueError("The shape of the pc and crval tables should match. "
                                 f"The pc table has shape {table_shape} and the "
                                 f"crval table has shape {crval_table.shape[:-1]}")
            table_shape = crval_table.shape[:-1]

        return table_shape

    @staticmethod
    def get_pc_crval(z, pc, crval):
        """
        Select the ``pc`` matrix and ``crval`` pair to use at coordinate ``z``.

        ``z`` (or its value, if it is a Quantity) is converted with ``int``, so
        fractional values are truncated toward zero. Static tables are returned
        unchanged.

        Raises
        ------
        IndexError
            If a table varies and the index is negative, which numpy would
            otherwise silently wrap to the end of the table, or if the index is
            past the end of a table.
        """
        if isinstance(z, u.Quantity):
            ind = int(z.value)
        else:
            ind = int(z)

        pc_varies = pc.shape != (2, 2)
        crval_varies = crval.shape != (2,)
        if ind < 0 and (pc_varies or crval_varies):
            raise IndexError(f"z={z} gives a negative table index ({ind}).")

        if pc_varies:
            pc = pc[ind]
        if crval_varies:
            crval = crval[ind]
        return pc, crval

    def __init__(self, *args, crval_table=None, pc_table=None, projection=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.pc_table = np.asanyarray(pc_table)
        self.crval_table = np.asanyarray(crval_table)
        self.table_shape = self._validate_table_shapes(self.pc_table, self.crval_table)
        if projection is None:
            # Create a fresh projection per model instead of sharing one
            # default instance (evaluated once at definition time) across all.
            projection = m.Pix2Sky_TAN()
        if not isinstance(projection, m.Pix2SkyProjection):
            raise TypeError("The projection keyword should be a Pix2SkyProjection model class.")
        self.projection = projection
class VaryingCelestialTransform(BaseVaryingCelestialTransform):
    """
    Pixel-to-celestial model whose ``pc`` matrix and/or ``crval`` are selected
    from tables by the third pixel coordinate ``z``.

    Inputs are ``(x, y, z)`` in pixels; outputs are ``(lon, lat)``.
    """

    n_inputs = 3
    n_outputs = 2

    @property
    def input_units(self):
        return {"x": u.pix, "y": u.pix, "z": u.pix}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inputs = ("x", "y", "z")
        self.outputs = ("lon", "lat")

    def _simple_transform(self, z, crpix, cdelt, lon_pole):
        """Build the static SimpleCelestialTransform for slice ``z``."""
        pc, crval = self.get_pc_crval(z, self.pc_table, self.crval_table)
        return SimpleCelestialTransform(
            crpix=crpix,
            cdelt=cdelt,
            pc=pc,
            crval=crval,
            lon_pole=lon_pole,
            projection=self.projection,
        )

    def transform_at_index(self, z):
        return self._simple_transform(z, self.crpix, self.cdelt, self.lon_pole)

    def evaluate(self, x, y, z, crpix, cdelt, lon_pole):
        # Parameter values are indexed with [0] to drop their extra leading axis.
        celestial = self._simple_transform(z, crpix[0], cdelt[0], lon_pole[0])
        return celestial(x, y)

    @property
    def inverse(self):
        return InverseVaryingCelestialTransform(
            crpix=self.crpix,
            cdelt=self.cdelt,
            lon_pole=self.lon_pole,
            pc_table=self.pc_table,
            crval_table=self.crval_table,
            projection=self.projection,
        )
class InverseVaryingCelestialTransform(BaseVaryingCelestialTransform):
    """
    Celestial-to-pixel model: ``(lon, lat, z) -> (x, y)``.

    The ``pc`` matrix and ``crval`` for each call are selected by ``z``.
    """

    n_inputs = 3
    n_outputs = 2

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inputs = ("lon", "lat", "z")
        self.outputs = ("x", "y")

    @property
    def input_units(self):
        return {"lon": u.deg, "lat": u.deg, "z": u.pix}

    def evaluate(self, lon, lat, z, crpix, cdelt, lon_pole, **kwargs):
        pc, crval = self.get_pc_crval(z, self.pc_table, self.crval_table)
        # Parameter values are indexed with [0] to drop their extra leading axis.
        settings = dict(crpix=crpix[0],
                        cdelt=cdelt[0],
                        pc=pc,
                        crval=crval,
                        lon_pole=lon_pole[0])
        forward_at_z = SimpleCelestialTransform(projection=self.projection, **settings)
        return forward_at_z.inverse(lon, lat)
class CoupledCompoundModel(CompoundModel):
    """
    This class takes two models which share one or more inputs on the forward
    transform, and where the left hand model's inverse is dependent on the
    output of the right hand model's inverse output.

    Parameters
    ----------
    op : `str`
        The operator to use, can only be ``'&'``.
    left : `astropy.modeling.Model`
        The left hand model, should have one or more inputs which are shared
        with the right hand model on the forward transform, and also rely on
        these inputs for the inverse transform.
    right : `astropy.modeling.Model`
        The right hand model, no special behaviour is required here.
    shared_inputs : `int`
        The number of inputs (counted from the end of the inputs to the left
        model and the start of the inputs to the right model) which are shared
        between the two models. Must be between 0 and
        ``min(left.n_inputs, right.n_inputs)``.

    Raises
    ------
    ValueError
        If ``op`` is not ``'&'`` or ``shared_inputs`` is out of range.

    Examples
    --------
    Take the following example with a time dependent celestial transform
    (modelled as dependent upon the pixel coordinate for time rather than the
    world coordinate).

    The forward transform uses the "z" pixel dimension as input to both the
    Celestial and Temporal models, this leads to the following transform in the
    forward direction::

          x   y             z
          │   │             │
          │   │   ┌─────────┤
          │   │   │         │
          ▼   ▼   ▼         ▼
        ┌───────────┐  ┌────────┐
        │ Celestial │  │Temporal│
        └─┬───┬─────┘  └────┬───┘
          │   │             │
          ▼   ▼             ▼
         lon lat           time

    This could trivially be reproduced using `~astropy.modeling.models.Mapping`.

    The complexity is in the reverse transform, where the inverse Celestial
    transform is also dependent upon the pixel coordinate z.
    This means that the output of the inverse Temporal transform has to be
    duplicated as an input to the Celestial transform's inverse.
    This is achieved by the use of the ``Mapping`` models in
    ``CoupledCompoundModel.inverse`` to create a multi-stage compound model
    which duplicates the output of the right hand side model::

         lon lat           time
          │   │             │
          │   │             ▼
          │   │       ┌───────────┐
          │   │       │ Temporal' │
          │   │       └──┬─────┬──┘
          │   │     z    │     │
          │   │   ┌──────┘     │
          │   │   │            │
          ▼   ▼   ▼            │
        ┌────────────┐         │
        │ Celestial' │         │
        └─┬───┬──────┘         │
          │   │                │
          ▼   ▼                ▼
          x   y                z
    """

    _separable = False

    def __init__(self, op, left, right, name=None, shared_inputs=1):
        if op != "&":
            raise ValueError(
                f"The {self.__class__.__name__} class should only be used with the & operator."
            )
        max_shared = min(left.n_inputs, right.n_inputs)
        if not 0 <= shared_inputs <= max_shared:
            raise ValueError(
                f"shared_inputs must be between 0 and {max_shared}, got {shared_inputs}."
            )
        super().__init__(op, left, right, name=name)
        self.n_inputs = self.n_inputs - shared_inputs
        # Slice by the new count: ``inputs[:-shared_inputs]`` would drop every
        # input when shared_inputs == 0.
        self.inputs = self.inputs[:self.n_inputs]
        self.shared_inputs = shared_inputs

    def _evaluate(self, *args, **kw):
        # The left model takes the first left.n_inputs arguments and the right
        # model the last right.n_inputs; the shared inputs fall in the overlap.
        leftval = self.left(*(args[:self.left.n_inputs]), **kw)
        rightval = self.right(*(args[-self.right.n_inputs:]), **kw)
        return self._apply_operators_to_value_lists(leftval, rightval, **kw)

    @property
    def inverse(self):
        left_inverse = self.left.inverse
        right_inverse = self.right.inverse

        # Inverse inputs are the forward outputs: [left outputs..., right outputs...].
        n_left_outputs = self.left.n_outputs
        n_recovered = right_inverse.n_outputs  # the right model's forward inputs

        # Step 1: pass the left model's outputs through unchanged while the
        # right inverse recovers the right model's inputs.
        passthrough = tuple(range(n_left_outputs))
        step1 = m.Mapping(passthrough) & right_inverse

        # Step 2: feed the left inverse its own outputs plus the *shared*
        # recovered inputs (the first ``shared_inputs`` of the right model's
        # inputs), and also emit all recovered right inputs unchanged at the end.
        recovered = tuple(range(n_left_outputs, n_left_outputs + n_recovered))
        inter_mapping = passthrough + recovered[:self.shared_inputs] + recovered
        step2 = (m.Mapping(inter_mapping, n_inputs=n_left_outputs + n_recovered)
                 | (left_inverse & m.Mapping(tuple(range(n_recovered)))))
        return step1 | step2
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment