from typing import Optional, Sequence, Union
import torch
import numpy as np
from caskade import forward
from .base import Model
from ..image import (
Image,
TargetImage,
TargetImageList,
ModelImage,
ModelImageList,
ImageList,
Window,
WindowList,
JacobianImage,
JacobianImageList,
)
from .. import config
from ..backend_obj import backend, ArrayLike
from ..utils.decorators import ignore_numpy_warnings
from ..errors import InvalidTarget, InvalidWindow
__all__ = ["GroupModel"]
# GroupModel is the only public name exported by this module (see __all__).
class GroupModel(Model):
    """Model object which represents a list of other models.

    For each general model method, this calls all the appropriate models
    from its list and combines their output into a single summed model.
    This class should be used when describing any system more complex than
    makes sense to represent with a single light distribution.

    Args:
        models (Optional[Sequence[Model]]): Model objects which are combined
            by the group model. ``None`` is treated as an empty group.
        **kwargs: forwarded unchanged to ``Model.__init__``.
            NOTE(review): earlier documentation listed ``name``, ``target``
            and ``locked`` here; confirm against the base class.
    """

    _model_type = "group"
    usable = True

    def __init__(
        self,
        *,
        models: Optional[Sequence[Model]] = None,
        **kwargs,
    ):
        """Validate and store the sub models, then build the group window.

        Raises:
            TypeError: if any entry of ``models`` is not a ``Model`` instance.
        """
        super().__init__(**kwargs)
        # The default is None, which cannot be iterated; treat it as "no
        # sub models" rather than failing with an opaque TypeError.
        if models is None:
            models = []
        for model in models:
            if not isinstance(model, Model):
                raise TypeError(f"Expected a Model instance in 'models', got {type(model)}")
        self.models = models
        self._update_window()

    def _update_window(self):
        """Make a new window which encloses the windows of all sub models.

        For a single-image target the result is the union of the sub model
        windows (``None`` when there are no sub models, in which case the
        ``window`` property falls back to the target window). For an
        ``ImageList`` target one window is built per target image: it starts
        as a copy of that image's window, is intersected with the first sub
        model window found for that image and is unioned with every later
        one.

        Raises:
            NotImplementedError: if a sub model target is neither an
                ``ImageList`` nor a ``TargetImage`` while the group target is
                an ``ImageList``.
        """
        if not isinstance(self.target, ImageList):
            new_window = None
            for model in self.models:
                if new_window is None:
                    new_window = model.window.copy()
                else:
                    new_window |= model.window
            self.window = new_window
            return

        # WindowList if target is a TargetImageList
        new_windows = [target.window.copy() for target in self.target]
        n_windows = [0] * len(self.target.images)

        def merge(target, window):
            # First window seen for an image clips the target window, later
            # windows extend the result.
            index = self.target.index(target)
            if n_windows[index] == 0:
                new_windows[index] &= window
            else:
                new_windows[index] |= window
            n_windows[index] += 1

        for model in self.models:
            if isinstance(model.target, ImageList):
                for target, window in zip(model.target, model.window):
                    merge(target, window)
            elif isinstance(model.target, TargetImage):
                merge(model.target, model.window)
            else:
                raise NotImplementedError(
                    f"Group_Model cannot construct a window for itself using {type(model.target)} object. Must be a Target_Image"
                )

        new_window = WindowList(new_windows)
        for i, n in enumerate(n_windows):
            if n == 0:
                config.logger.warning(
                    f"Model {self.name} has no sub models in target '{self.target.images[i].name}', this may cause issues with fitting."
                )
        self.window = new_window

    @ignore_numpy_warnings
    def initialize(self):
        """Initialize each model in this group, in order.

        Each sub model's own ``initialize`` method is called after logging
        its name at info level.
        """
        for model in self.models:
            config.logger.info(f"Initializing model {model.name}")
            model.initialize()

    def match_window(self, image: Union[Image, ImageList], window: Window, model: Model) -> Window:
        """Select the window to use when evaluating ``model`` on ``image``.

        Args:
            image: image (or image list) being sampled on.
            window: window returned as-is when both ``image`` and
                ``model.target`` are single images with the same identity.
            model: sub model whose target is matched against ``image``.

        Returns:
            The window (or ``WindowList``) to use for ``model``.

        Raises:
            IndexError: if ``model.target`` has no counterpart in ``image``.
            NotImplementedError: for unsupported image/target type pairs.
        """
        if isinstance(image, ImageList) and isinstance(model.target, ImageList):
            indices = image.match_indices(model.target)
            if len(indices) == 0:
                raise IndexError("None of the model targets are present in the image list")
            use_window = WindowList(windows=[image.images[i].window for i in indices])
        elif isinstance(image, ImageList) and isinstance(model.target, Image):
            try:
                image.index(model.target)
            except ValueError as err:
                raise IndexError("The model target is not present in the image list") from err
            use_window = model.window
        elif isinstance(image, Image) and isinstance(model.target, ImageList):
            try:
                i = model.target.index(image)
            except ValueError as err:
                raise IndexError("The image is not present in the model target list") from err
            use_window = model.window[i]
        elif isinstance(image, Image) and isinstance(model.target, Image):
            if image.identity != model.target.identity:
                raise IndexError("The image and the model target have different identities")
            use_window = window
        else:
            raise NotImplementedError(
                f"Group_Model cannot sample with {type(image)} and {type(model.target)}"
            )
        return use_window

    def _ensure_vmap_compatible(
        self, image: Union[Image, ImageList], other: Union[Image, ImageList]
    ):
        """Add a zero derived from ``other`` to every matching ``image``.

        Image lists on either side are expanded recursively. For each pair
        of single images with equal ``identity``, a zero with the type of
        one element of ``other._data`` is added to ``image`` in place.
        NOTE(review): presumably this lets ``image`` take on the batching of
        ``other`` under vmap before accumulation; confirm with the backend.
        """
        if isinstance(image, ImageList):
            for img in image.images:
                self._ensure_vmap_compatible(img, other)
            return
        if isinstance(other, ImageList):
            for img in other.images:
                self._ensure_vmap_compatible(image, img)
            return
        if image.identity == other.identity:
            image += backend.zeros_like(other._data[0, 0])

    @forward
    def sample(
        self,
        _CD: Optional[ArrayLike] = None,
        _crtan: Optional[ArrayLike] = None,
        _crpix: Optional[ArrayLike] = None,
        _psf: Optional[ArrayLike] = None,
    ) -> Union[ModelImage, ModelImageList]:
        """Sample the group model over its window.

        A blank model image is created from the target over the group
        window, then each sub model is evaluated individually and its image
        is added into that larger image.

        **Args:**
        -  `_CD`, `_crtan`, `_crpix`, `_psf` (Optional[ArrayLike]): passed
           unchanged as keyword arguments to every sub model call.

        Returns:
            The summed model image (a list of images for an image-list target).
        """
        image = self.target.model_image(self.window)
        for model in self.models:
            model_image = model(_CD=_CD, _crtan=_crtan, _crpix=_crpix, _psf=_psf)
            self._ensure_vmap_compatible(image, model_image)
            image += model_image
        return image

    def jacobian(
        self,
        pass_jacobian: Optional[Union[JacobianImage, JacobianImageList]] = None,
        params=None,
    ) -> Union[JacobianImage, JacobianImageList]:
        """Compute the jacobian for this model.

        A jacobian image is built from the windowed target (unless one is
        passed in), then threaded through the ``jacobian`` method of each
        sub model, which returns the updated image.

        **Args:**
        -  `pass_jacobian` (Optional[JacobianImage]): A Jacobian image pre-constructed to be passed along instead of constructing new Jacobians
        -  `params` (Optional[Sequence[Param]]): Parameters to use for the jacobian. If provided they are applied with ``set_values`` first; otherwise the model's current parameters are used.
        """
        if params is not None:
            self.set_values(params)
        if pass_jacobian is None:
            jac_img = self.target[self.window].jacobian_image(
                parameters=self.build_params_array_identities()
            )
        else:
            jac_img = pass_jacobian
        for model in self.models:
            jac_img = model.jacobian(pass_jacobian=jac_img)
        return jac_img

    def __iter__(self):
        """Iterate over the sub models of this group."""
        return iter(self.models)

    @property
    def target(self) -> Optional[Union[TargetImage, TargetImageList]]:
        """The target image(s) for this group, or ``None`` if never set."""
        try:
            return self._target
        except AttributeError:
            return None

    @target.setter
    def target(self, tar: Optional[Union[TargetImage, TargetImageList]]):
        if not (tar is None or isinstance(tar, (TargetImage, TargetImageList))):
            raise InvalidTarget("Group_Model target must be a Target_Image instance.")
        try:
            del self._target  # Remove old target if it exists
        except AttributeError:
            pass
        self._target = tar

    @property
    def window(self) -> Optional[Union[Window, WindowList]]:
        """The window defines a region on the sky in which this model will be
        optimized and typically evaluated. Two models with
        non-overlapping windows are in effect independent of each
        other. If there is another model with a window that spans both
        of them, then they are tenuously connected.

        If not provided, the model will assume a window equal to the
        target it is fitting. Note that in this case the window is not
        explicitly set to the target window, so if the model is moved
        to another target then the fitting window will also change.

        Raises:
            ValueError: if neither a window nor a target is available.
        """
        if self._window is None:
            if self.target is None:
                raise ValueError(
                    "This model has no target or window, these must be provided by the user"
                )
            return self.target.window
        return self._window

    @window.setter
    def window(self, window):
        if window is None:
            self._window = None
        elif isinstance(window, (Window, WindowList)):
            self._window = window
        elif len(window) in [2, 4]:
            # A 2- or 4-element sequence is handed to Window for conversion.
            self._window = Window(window, image=self.target)
        else:
            raise InvalidWindow(f"Unrecognized window format: {str(window)}")

    def segmentation_map(self) -> ArrayLike:
        """Generate a segmentation map for this group model.

        Each pixel in the segmentation map is assigned an integer value
        corresponding to the index of the sub-model that corresponds to that
        pixel. The pixels are assigned based on "relative importance",
        meaning that for each pixel, the sub-model which contributes the
        largest fraction of its own total flux to that pixel is assigned to
        it. Pixels to which no sub-model contributes keep the value ``-1``;
        ties go to the sub-model that comes first in the group.

        Returns:
            ArrayLike: Segmentation map with the same shape as the target image as windowed by the group model window.

        Raises:
            NotImplementedError: if the windowed target is an ``ImageList``.
        """
        subtarget = self.target[self.window]
        if isinstance(subtarget, ImageList):
            raise NotImplementedError(
                "Segmentation maps are not currently supported for ImageList targets. Please apply one target at a time."
            )

        # -1 marks pixels not claimed by any sub-model.
        seg_map = backend.zeros_like(subtarget._data, dtype=backend.int32) - 1
        # Running per-pixel maximum of the flux fractions seen so far.
        max_flux_frac = backend.zeros_like(subtarget._data)
        for idx, model in enumerate(self.models):
            model_image = model()
            abs_flux = backend.abs(model_image._data)
            model_flux_frac = abs_flux / backend.sum(abs_flux)
            # Embed the model-window fractions into a full-size array that
            # is zero everywhere outside the model window.
            indices = subtarget.get_indices(model.window)
            model_flux_frac_full = backend.fill_at_indices(
                backend.zeros_like(subtarget._data), indices, model_flux_frac
            )
            # Strict comparison: a zero fraction (pixel outside the model
            # window) must never claim a pixel, otherwise the -1 marker
            # would always be overwritten.
            update_mask = model_flux_frac_full > max_flux_frac
            seg_map = backend.where(update_mask, idx, seg_map)
            max_flux_frac = backend.where(update_mask, model_flux_frac_full, max_flux_frac)
        return seg_map.T

    def deblend(self) -> Sequence[TargetImage]:
        """Generate deblended images for each sub-model in this group model.

        For each sub-model, the target data (and variance) inside the
        sub-model window is multiplied by the ratio of the sub-model image
        to the full group model image, i.e. the fraction of the total model
        flux at each pixel which is contributed by that sub-model.

        Returns:
            Sequence[TargetImage]: List of deblended TargetImage objects for each sub-model.

        Raises:
            NotImplementedError: if the windowed target is an ``ImageList``.
        """
        subtarget = self.target[self.window]
        if isinstance(subtarget, ImageList):
            raise NotImplementedError(
                "Deblending is not currently supported for ImageList targets. Please apply one target at a time."
            )

        # Evaluate the full model only once the target type is known to be
        # supported.
        full_model = self()
        deblended_images = []
        for model in self.models:
            model_image = model()
            subfull_model = full_model[model.window]
            subsubtarget = subtarget[model.window].copy(
                name=f"deblend_{model.name}_{subtarget.name}"
            )
            deblend_data = subsubtarget.data * model_image.data / subfull_model.data
            deblend_variance = subsubtarget.variance * model_image.data / subfull_model.data
            subsubtarget.data = deblend_data
            subsubtarget.variance = deblend_variance
            deblended_images.append(subsubtarget)
        return deblended_images