Source code for astrophot.image.mixins.sip_mixin

from typing import Union, Optional, Tuple

from ..image_object import Image
from ..window import Window
from .. import func
from ... import config
from ...backend_obj import backend, ArrayLike
from ...utils.interpolate import interp2d
from ...param import forward


[docs] class SIPMixin: """A mixin class for SIP (Simple Image Polynomial) distortion model.""" expect_ctype = (("RA---TAN-SIP",), ("DEC--TAN-SIP",)) def __init__( self, *args, sipA: dict[Tuple[int, int], float] = {}, sipB: dict[Tuple[int, int], float] = {}, sipAP: dict[Tuple[int, int], float] = {}, sipBP: dict[Tuple[int, int], float] = {}, pixel_area_map: Optional[ArrayLike] = None, distortion_ij: Optional[ArrayLike] = None, distortion_IJ: Optional[ArrayLike] = None, filename: Optional[str] = None, **kwargs, ): super().__init__(*args, filename=filename, **kwargs) if filename is not None: return self.sipA = sipA self.sipB = sipB self.sipAP = sipAP self.sipBP = sipBP if len(self.sipAP) == 0 and len(self.sipA) > 0: self.compute_backward_sip_coefs() self.update_distortion_model( distortion_ij=distortion_ij, distortion_IJ=distortion_IJ, pixel_area_map=pixel_area_map )
[docs] @forward def pixel_to_plane( self, i: ArrayLike, j: ArrayLike, crtan: ArrayLike, CD: ArrayLike, ) -> Tuple[ArrayLike, ArrayLike]: di = interp2d(self.distortion_ij[0], i, j, padding_mode="border") dj = interp2d(self.distortion_ij[1], i, j, padding_mode="border") return func.pixel_to_plane_linear(i + di, j + dj, *self.crpix, CD, *crtan)
[docs] @forward def plane_to_pixel( self, x: ArrayLike, y: ArrayLike, crtan: ArrayLike, CD: ArrayLike, ) -> Tuple[ArrayLike, ArrayLike]: I, J = func.plane_to_pixel_linear(x, y, *self.crpix, CD, *crtan) dI = interp2d(self.distortion_IJ[0], I, J, padding_mode="border") dJ = interp2d(self.distortion_IJ[1], I, J, padding_mode="border") return I + dI, J + dJ
[docs] @forward def pixel_collecting_area(self, I_, J_, upsample=1): # CMOS pixels only sensitive in sub area, so scale the pixel collecting area return interp2d(self.pixel_area_map, I_, J_, padding_mode="border") / upsample**2
@property @forward def pixel_area_map(self): return self._pixel_area_map * self.pixel_area @property def A_ORDER(self) -> int: if self.sipA: return max(a + b for a, b in self.sipA) return 0 @property def B_ORDER(self) -> int: if self.sipB: return max(a + b for a, b in self.sipB) return 0
[docs] def compute_backward_sip_coefs(self): """ Credit: Shu Liu and Lei Hi, see here: https://github.com/Roman-Supernova-PIT/sfft/blob/master/sfft/utils/CupyWCSTransform.py Compute the backward transformation from (U, V) to (u, v) """ i, j = self.pixel_center_meshgrid() u, v = i - self.crpix[0], j - self.crpix[1] du, dv = func.sip_delta(u, v, self.sipA, self.sipB) U = (u + du).flatten() V = (v + dv).flatten() AP, BP = func.sip_backward_transform( u.flatten(), v.flatten(), U, V, self.A_ORDER, self.B_ORDER ) self.sipAP = dict( ((p, q), ap.item()) for (p, q), ap in zip(func.sip_coefs(self.A_ORDER), AP) ) self.sipBP = dict( ((p, q), bp.item()) for (p, q), bp in zip(func.sip_coefs(self.B_ORDER), BP) )
[docs] def update_distortion_model( self, distortion_ij: Optional[ArrayLike] = None, distortion_IJ: Optional[ArrayLike] = None, pixel_area_map: Optional[ArrayLike] = None, ): """ Update the pixel area map based on the current SIP coefficients. """ # Pixelized distortion model ############################################################# if distortion_ij is None or distortion_IJ is None: i, j = self.pixel_center_meshgrid() u, v = i - self.crpix[0], j - self.crpix[1] if distortion_ij is None: distortion_ij = backend.stack(func.sip_delta(u, v, self.sipA, self.sipB), dim=0) if distortion_IJ is None: # fixme maybe distortion_IJ = backend.stack(func.sip_delta(u, v, self.sipAP, self.sipBP), dim=0) self.distortion_ij = distortion_ij self.distortion_IJ = distortion_IJ # Pixel area map ############################################################# if pixel_area_map is not None: self._pixel_area_map = pixel_area_map / self.pixel_area return i, j = self.pixel_corner_meshgrid() x, y = self.pixel_to_plane(i, j) # Shoelace formula for pixel area # 1: [:-1, :-1] # 2: [:-1, 1:] # 3: [1:, 1:] # 4: [1:, :-1] A = 0.5 * ( x[:-1, :-1] * y[:-1, 1:] + x[:-1, 1:] * y[1:, 1:] + x[1:, 1:] * y[1:, :-1] + x[1:, :-1] * y[:-1, :-1] - ( x[:-1, 1:] * y[:-1, :-1] + x[1:, 1:] * y[:-1, 1:] + x[1:, :-1] * y[1:, 1:] + x[:-1, :-1] * y[1:, :-1] ) ) self._pixel_area_map = backend.abs(A) / self.pixel_area
[docs] def to(self, dtype=None, device=None): if dtype is None: dtype = config.DTYPE if device is None: device = config.DEVICE super().to(dtype=dtype, device=device) self._pixel_area_map = backend.to(self._pixel_area_map, dtype=dtype, device=device) self.distortion_ij = backend.to(self.distortion_ij, dtype=dtype, device=device) self.distortion_IJ = backend.to(self.distortion_IJ, dtype=dtype, device=device)
[docs] def copy_kwargs(self, **kwargs): kwargs = { "sipA": self.sipA, "sipB": self.sipB, "sipAP": self.sipAP, "sipBP": self.sipBP, "pixel_area_map": self.pixel_area_map, "distortion_ij": self.distortion_ij, "distortion_IJ": self.distortion_IJ, **kwargs, } return super().copy_kwargs(**kwargs)
[docs] def get_window(self, other: Union[Image, Window], indices=None, **kwargs): """Get a sub-region of the image as defined by an other image on the sky.""" if indices is None: indices = self.get_indices(other if isinstance(other, Window) else other.window) return super().get_window( other, pixel_area_map=self.pixel_area_map[indices], distortion_ij=self.distortion_ij[:, indices[0], indices[1]], distortion_IJ=self.distortion_IJ[:, indices[0], indices[1]], indices=indices, **kwargs, )
[docs] def fits_info(self): info = super().fits_info() info["CTYPE1"] = "RA---TAN-SIP" info["CTYPE2"] = "DEC--TAN-SIP" a_order = 0 for a, b in self.sipA: info[f"A_{a}_{b}"] = self.sipA[(a, b)] a_order = max(a_order, a + b) info["A_ORDER"] = a_order b_order = 0 for a, b in self.sipB: info[f"B_{a}_{b}"] = self.sipB[(a, b)] b_order = max(b_order, a + b) info["B_ORDER"] = b_order ap_order = 0 for a, b in self.sipAP: info[f"AP_{a}_{b}"] = self.sipAP[(a, b)] ap_order = max(ap_order, a + b) info["AP_ORDER"] = ap_order bp_order = 0 for a, b in self.sipBP: info[f"BP_{a}_{b}"] = self.sipBP[(a, b)] bp_order = max(bp_order, a + b) info["BP_ORDER"] = bp_order return info
[docs] def load(self, filename: str, hduext: int = 0): hdulist = super().load(filename, hduext=hduext) self.sipA = {} if "A_ORDER" in hdulist[hduext].header: a_order = hdulist[hduext].header["A_ORDER"] for i in range(a_order + 1): for j in range(a_order + 1 - i): key = (i, j) if f"A_{i}_{j}" in hdulist[hduext].header: self.sipA[key] = hdulist[hduext].header[f"A_{i}_{j}"] self.sipB = {} if "B_ORDER" in hdulist[hduext].header: b_order = hdulist[hduext].header["B_ORDER"] for i in range(b_order + 1): for j in range(b_order + 1 - i): key = (i, j) if f"B_{i}_{j}" in hdulist[hduext].header: self.sipB[key] = hdulist[hduext].header[f"B_{i}_{j}"] self.sipAP = {} if "AP_ORDER" in hdulist[hduext].header: ap_order = hdulist[hduext].header["AP_ORDER"] for i in range(ap_order + 1): for j in range(ap_order + 1 - i): key = (i, j) if f"AP_{i}_{j}" in hdulist[hduext].header: self.sipAP[key] = hdulist[hduext].header[f"AP_{i}_{j}"] self.sipBP = {} if "BP_ORDER" in hdulist[hduext].header: bp_order = hdulist[hduext].header["BP_ORDER"] for i in range(bp_order + 1): for j in range(bp_order + 1 - i): key = (i, j) if f"BP_{i}_{j}" in hdulist[hduext].header: self.sipBP[key] = hdulist[hduext].header[f"BP_{i}_{j}"] self.update_distortion_model() return hdulist