Source code for astrophot.models.basis_psf

from typing import Union
import torch
import numpy as np

from .psf_model_object import PSFModel
from ..utils.decorators import ignore_numpy_warnings, combine_docstrings
from ..utils.interpolate import interp2d
from .. import config
from ..backend_obj import backend, ArrayLike
from ..errors import SpecificationConflict
from ..param import forward
from . import func

__all__ = ["PixelBasisPSF"]


@combine_docstrings
class PixelBasisPSF(PSFModel):
    """Point source model that represents the PSF as a weighted sum of basis images.

    The PSF image is formed by multiplying each basis image by its weight and
    summing the results. Using bilinear interpolation it will shift the PSF
    within a pixel to accurately represent the center location of a point
    source. There is no functional form for this object type as any image set
    can be supplied. Bilinear interpolation is very fast and accurate for
    smooth models, so it is possible to do the expensive interpolation before
    optimization and save time.

    The basis may be supplied either as a stack of images whose first axis
    indexes the basis elements, or as a ``"zernike:<order>"`` string, in which
    case a Zernike basis sized to the model window is generated in
    ``initialize``.

    :param weights: The coefficients multiplying each basis image (declared
        ``"unitless"`` in ``_parameter_specs``).
    """

    _model_type = "basis"
    _parameter_specs = {"weights": {"units": "unitless", "shape": (None,), "dynamic": True}}
    usable = True

    def __init__(self, *args, basis: Union[str, ArrayLike] = "zernike:3", **kwargs):
        """Initialize the PixelBasisPSF model with a basis set of images.

        :param basis: Either a ``"zernike:<order>"`` specification string or an
            array-like stack of basis images (first axis indexes the basis).
        """
        super().__init__(*args, **kwargs)
        self.basis = basis

    @property
    def basis(self):
        """The basis set of images used to form the eigen point source.

        Holds the ``"zernike:<order>"`` string until ``initialize`` replaces it
        with the generated image stack; otherwise holds the (transposed) array.
        """
        return self._basis

    @basis.setter
    def basis(self, value: Union[str, ArrayLike]):
        """Set the basis set of images.

        :param value: Either a ``"zernike:<order>"`` specification string, which
            is stored as-is and expanded into images by ``initialize``, or an
            array-like stack of basis images, which is converted to a backend
            array and has its last two axes swapped.
        :raises SpecificationConflict: If ``value`` is None, or is a string that
            is not a valid ``"zernike:<order>"`` specification.
        """
        if value is None:
            raise SpecificationConflict(
                "PixelBasisPSF requires a basis set of images to be provided."
            )
        if isinstance(value, str):
            if not value.startswith("zernike:"):
                # Previously such strings reached backend.as_array and failed obscurely.
                raise SpecificationConflict(
                    f"Unrecognized PixelBasisPSF basis specification {value!r}; "
                    "expected 'zernike:<order>' or an array of basis images."
                )
            # Validate the order now so a typo fails at construction, not in initialize().
            self._parse_zernike_order(value)
            self._basis = value
            return
        # Transpose since pytorch uses (j, i) indexing when (i, j) is more natural for coordinates
        self._basis = backend.transpose(
            backend.as_array(value, dtype=config.DTYPE, device=config.DEVICE), 2, 1
        )

    @staticmethod
    def _parse_zernike_order(spec: str) -> int:
        """Return the integer order from a ``"zernike:<order>"`` specification.

        :raises SpecificationConflict: If the text after the first colon is not
            an integer.
        """
        try:
            return int(spec.split(":")[1])
        except ValueError as err:
            raise SpecificationConflict(
                f"Invalid zernike basis specification {spec!r}; expected 'zernike:<int>'."
            ) from err

    @torch.no_grad()
    @ignore_numpy_warnings
    def initialize(self):
        """Build the Zernike basis if one was requested and set default weights.

        When the basis is a ``"zernike:<order>"`` string, it is replaced by the
        Zernike image stack sized to the model window and divided by the target
        pixel area. If the weights are not yet initialized, they default to 1
        for the first basis image and 0 for all others.
        """
        super().initialize()
        if isinstance(self.basis, str) and self.basis.startswith("zernike:"):
            order = self._parse_zernike_order(self.basis)
            size = int(max(self.window.shape))
            # Round up to an odd size so the basis images have a single central pixel.
            size = size + 1 - size % 2
            # Assigning through the setter converts to an array and applies the transpose.
            self.basis = func.zernike_basis(order, size) / self.target.pixel_area
        if not self.weights.initialized:
            # Put all the initial weight on the first basis image.
            w = np.zeros(self.basis.shape[0])
            w[0] = 1.0
            self.weights.value = w

    @forward
    def brightness(self, x: ArrayLike, y: ArrayLike, weights: ArrayLike) -> ArrayLike:
        """Evaluate the PSF brightness at coordinates ``(x, y)``.

        The basis images are combined into one image as a weighted sum, and
        that image is sampled with ``interp2d``. Before sampling, the
        transformed coordinates are scaled by the target upsampling factor and
        offset by half the image size, so that the coordinate origin lands on
        the central pixel.

        :param x: First coordinate of the sample points (before
            ``transform_coordinates``).
        :param y: Second coordinate of the sample points (before
            ``transform_coordinates``).
        :param weights: One coefficient per basis image; presumably injected by
            ``@forward`` from the model's ``weights`` parameter (confirm).
        :returns: Interpolated brightness at the requested coordinates.
        """
        x, y = self.transform_coordinates(x, y)
        # Broadcast each weight over its (2D) basis image, then collapse the basis axis.
        wB = backend.sum(weights[:, None, None] * self.basis, dim=0)
        u = self.target.upsample
        return interp2d(wB, x * u + wB.shape[0] // 2, y * u + wB.shape[1] // 2)