from typing import Tuple
from caskade import forward
from .base import Model
from ..image import PSFImage
from ..errors import InvalidTarget
from .mixins import SampleMixin, GradMixin
from ..backend_obj import backend, ArrayLike
# Public API of this module: only ``PSFModel`` is exported by ``import *``.
__all__ = ("PSFModel",)
class PSFModel(GradMixin, SampleMixin, Model):
    """Prototype point source (typically a star) model, to be subclassed
    by other point source models which define specific behavior.

    PSF models behave differently than component models. Their target image
    must be a ``PSFImage`` object instead of a ``TargetImage`` object.
    PSF models do not fit a free ``center`` parameter; their center is
    always ``(0, 0)`` in pixel coordinates, matching the convention of a
    ``PSFImage``. A PSF model is never convolved with another PSF model.

    :param center: Center of the PSF in pixel coordinates ``[x, y]``.
        Fixed at ``(0, 0)`` by default and not included in the fit.
    :param normalize_psf: When ``True`` (default) the sampled PSF is
        normalized so that its total flux within the fitting window equals 1.
    """

    _parameter_specs = {
        "center": {
            "units": "pix",
            "value": (0.0, 0.0),
            "shape": (2,),
            # Not a free parameter: the PSF center is pinned at (0, 0).
            "dynamic": False,
            "description": "center of the PSF in pixel coordinates [x, y], fixed at (0,0)",
        },
    }
    _model_type = "psf"
    # This is a prototype (see class docstring) and is not used directly;
    # concrete subclasses are expected to override this flag.
    usable = False
    # The sampled PSF will be normalized to a total flux of 1 within the window
    normalize_psf = True
    # Parameters which are treated specially by the model object and should
    # not be updated directly when initializing
    _options = ("normalize_psf",)

    def initialize(self):
        """Do nothing; this is intentionally a no-op for the PSF prototype.

        The only parameter declared here, ``center``, is fixed at ``(0, 0)``,
        so it needs no data-driven initial value.
        """

    @property
    def upsample(self):
        """Upsampling setting, delegated to the target ``PSFImage``."""
        return self.target.upsample

    @property
    def pad(self):
        """Padding setting, delegated to the target ``PSFImage``."""
        return self.target.pad

    @forward
    def pixel_brightness(self, i, j):
        """Evaluate the model at the pixel coordinates defined by i and j (of
        the target image).

        For a PSF model, this is the same as the brightness since it is
        defined in pixel units. The coordinates are first mapped through
        ``self.target.mypixel_to_targpixel``.
        """
        return self.brightness(*self.target.mypixel_to_targpixel(i, j))

    def _prep_psf(self):
        """Return the PSF-preparation triple with no PSF.

        A PSF model is never convolved with another PSF, so the PSF entry
        is ``None``.
        NOTE(review): the ``1`` and ``0`` entries are interpreted by the base
        ``Model`` (presumably an upsample factor and a pad) — confirm there.
        """
        return None, 1, 0

    # Fit loop functions
    ######################################################################

    @forward
    def sample(
        self,
        i: ArrayLike,
        j: ArrayLike,
        *args,
        **kwargs,
    ) -> ArrayLike:
        """
        Sample the PSF model on the pixel grid defined by i and j.

        Depending on the model specification, this may involve supersampling for
        higher precision, or it may just be a direct evaluation of the model at
        the pixel centers. The output is the flux evaluated over the pixel grid
        at native resolution (for the PSFImage associated with this model.)

        :param i: 2D array of x-coordinates of pixel centers (or pre-upsampled
            according to the ``sampling_mode``) in pixel units.
        :param j: 2D array of y-coordinates of pixel centers (or pre-upsampled
            according to the ``sampling_mode``) in pixel units.
        :param args: Accepted for interface compatibility; unused here.
        :param kwargs: Accepted for interface compatibility; unused here.
        :returns: 2D array (``Z``) of flux values at each pixel center,
            representing the PSF model evaluated at those coordinates. This is
            a raw array, not a ``PSFImage``; ``__call__`` wraps it in one.
        """
        Z = self.pixel_brightness(i, j)
        Z = self._pixel_integrator(Z)
        # Integration may change the grid, so recover the pixel-center
        # coordinates before adaptive refinement.
        i, j = self._pixel_center_finder(i, j)
        Z = self._adaptive_integrator(Z, i, j, 1, self.pixel_brightness)
        # Convert brightness to flux per pixel.
        return Z * self.target.pixel_area

    @property
    def target(self):
        """The ``PSFImage`` this model is fit to, or ``None`` if not yet set."""
        return getattr(self, "_target", None)

    @target.setter
    def target(self, target):
        """Set the target image.

        :param target: A ``PSFImage``, or ``None`` to clear the target.
        :raises InvalidTarget: If ``target`` is neither ``None`` nor a
            ``PSFImage``.
        """
        # NOTE(review): the None branch falls through to the del/assign
        # below, so ``_target`` ends up ``None`` either way.
        if target is None:
            self._target = None
        elif not isinstance(target, PSFImage):
            raise InvalidTarget(f"Target for PSFModel must be a PSFImage, not {type(target)}")
        # Remove the old target, if any, before assigning the new one.
        # Presumably this lets the caskade graph unlink the old node
        # cleanly — TODO confirm.
        try:
            del self._target
        except AttributeError:
            pass
        self._target = target

    @forward
    def __call__(self) -> PSFImage:
        """Sample the PSF over ``self.window`` and return it as an image.

        The image comes from ``self.target.model_image(self.window)``, and its
        data is filled with ``self.sample(...)``. When ``normalize_psf`` is
        true, the data is divided by its sum so that it totals 1.
        """
        working_image = self.target.model_image(self.window)
        i, j = self._pixel_meshgridder(self.target, self.window, 0, 1)
        working_image._data = self.sample(i, j)
        if self.normalize_psf:
            # NOTE(review): an all-zero sample would divide by zero here.
            working_image._data = working_image._data / backend.sum(working_image._data)
        return working_image