# Source code for astrophot.models.batch_model_object

from typing import Optional
from ..backend_obj import backend
from ..param import forward
from ..errors import SpecificationConflict
from .base import Model
from .mixins import SampleMixin, GradMixin
from .model_object import ComponentModel
from ..image import TargetImage, Window, TargetImageBatch, WindowBatch
from . import func
from .. import config


class BatchModel(GradMixin, SampleMixin, Model):
    """Many parameter sets of one component model, evaluated on one window.

    Every member of the batch shares the target, window and mask of the
    wrapped ``model``; the batch axis comes from ``model_params`` (mapped with
    ``model_dims``). Typical uses are a crowded region of the sky with many
    overlapping sources, or a single object built from many components (a
    generalization of the multi-Gaussian expansion model). To model the same
    object in several images, use BatchSceneModel instead.

    **Note:** any model parameters that you wish to batch over must be set to
    dynamic=True. See [caskade hierarchical models](https://caskade.readthedocs.io/en/latest/notebooks/HierarchicalModels.html)
    for more details.
    """

    usable = True
    _model_type = "batch"

    def __init__(self, *, model: ComponentModel = None, **kwargs):
        super().__init__(**kwargs)
        assert isinstance(
            model, ComponentModel
        ), "BatchModel must be initialized with a ComponentModel instance."
        # Attach the component as the ``model`` child (caskade hierarchical link).
        self.hierarchical_link("model", model)

    def initialize(self):
        """Delegate initialization to the wrapped component model."""
        self.model.initialize()

    # Target, window and mask all live on the wrapped component model. The
    # setters deliberately discard assignments so that nothing stored on this
    # wrapper can drift away from the component's own values.

    @property
    def target(self) -> Optional[TargetImage]:
        """Target image of the wrapped component model."""
        return self.model.target

    @target.setter
    def target(self, target: Optional[TargetImage]):
        """Ignored: the target is always read from ``self.model``."""

    @property
    def window(self) -> Optional[Window]:
        """Region of the sky in which this model is evaluated and optimized.

        This is always the wrapped component model's window. Models with
        non-overlapping windows are effectively independent of each other
        (unless some other model's window spans both). When the component has
        no explicit window it uses its target's window, so moving it to a new
        target also moves the fitting window.
        """
        return self.model.window

    @window.setter
    def window(self, window):
        """Ignored: the window is always read from ``self.model``."""

    @property
    def mask(self):
        """Mask of the wrapped component model."""
        return self.model.mask

    @mask.setter
    def mask(self, mask):
        """Ignored: the mask is always read from ``self.model``."""

    @forward
    def __call__(self, model_params=None, model_dims=None, **kwargs):
        """Evaluate every batch member and return their summed model image.

        Parameters
        ----------
        model_params
            Batched parameter values, passed as the last argument of
            ``self.model.sample``.
        model_dims
            ``in_dims`` entry for ``model_params`` in ``backend.vmap``; all
            other sampling arguments are shared (not mapped).

        Returns
        -------
        A model image on ``self.window``. Its data is the sum over the batch,
        convolved with the PSF when one applies (see below), trimmed by the
        padding from ``_prep_psf`` and downsampled by ``upsample``.
        """
        component = self.model
        psf, upsample, pad = component._prep_psf()
        working_image = self.target.model_image(self.window)
        rows, cols = component._pixel_meshgridder(
            self.target, self.window, pad, upsample
        )

        # Only the final argument (the parameters) may carry a batch axis.
        shared_dims = (None,) * 5
        sample_batch = backend.vmap(
            component.sample, in_dims=shared_dims + (model_dims,)
        )
        stacked = sample_batch(rows, cols, None, pad, upsample, model_params)

        # Collapse the batch first: convolution is linear, so convolving the
        # sum once equals summing the individual convolutions.
        flux = backend.sum(stacked, dim=0)

        needs_convolution = psf is not None and not component.internal_psf
        if needs_convolution:
            kernel = psf()._data if isinstance(psf, Model) else psf
            if kernel.shape != (1, 1):  # a 1x1 kernel is the identity PSF
                flux = func.convolve(flux, kernel)

        # Remove ``pad`` pixels from every edge, then undo the upsampling.
        row_stop = flux.shape[0] - pad
        col_stop = flux.shape[1] - pad
        working_image._data = func.downsample(
            flux[pad:row_stop, pad:col_stop], upsample
        )
        return working_image
class BatchSceneModel(GradMixin, Model):
    """A single model as viewed in multiple images.

    This model is quite restrictive in its use, but can provide a significant
    speedup by vectorizing the evaluation of the model. Some key things to
    keep in mind:

    - All model parameters that you wish to batch over must be set to
      dynamic=True. See [caskade hierarchical models](https://caskade.readthedocs.io/en/latest/notebooks/HierarchicalModels.html)
      for more details.
    - You must use a TargetImageBatch as the target, meaning that all the
      images must have the same number of pixels.
    - You must use a WindowBatch as the window (or None to just use the full
      images), meaning that the windows must all be the same shape in pixels.
    - The model you provide must have a TargetImage target and Window window,
      and these must have the same number of pixels as the batched versions.
    - If the base model has a PSFModel as its PSF then you cannot override it
      (that's part of the model), otherwise the PSF from the TargetImageBatch
      will override the model PSF for all images.
    """

    usable = True
    _model_type = "batch scene"

    def __init__(self, *, model: Model = None, **kwargs):
        """Wrap ``model`` so it can be evaluated across a batch of images.

        Raises
        ------
        SpecificationConflict
            If ``model`` is missing, or its target is not a TargetImage.
        """
        super().__init__(**kwargs)
        # ``model`` defaults to None but is required. Check explicitly so the
        # caller gets a clear error instead of an AttributeError on
        # ``None.target`` below.
        if model is None:
            raise SpecificationConflict(
                "BatchSceneModel must be initialized with a model."
            )
        if not isinstance(model.target, TargetImage):
            raise SpecificationConflict(
                f"BatchSceneModel can only be used with models that have a TargetImage as their target, not a {type(model.target).__name__}."
            )
        self.hierarchical_link("model", model)

    def initialize(self):
        """Delegate initialization to the wrapped model."""
        self.model.initialize()

    @property
    def target(self) -> Optional[TargetImageBatch]:
        """The batch of target images this model is evaluated on."""
        return self._target

    @target.setter
    def target(self, target: TargetImageBatch):
        # NOTE(review): validated with ``assert``, so this check is skipped
        # when Python runs with -O.
        assert isinstance(
            target, TargetImageBatch
        ), "BatchSceneModel target must be a TargetImageBatch."
        self._target = target

    @property
    def window(self) -> WindowBatch:
        """Batch of fitting windows; the target's window when none was set."""
        if self._window is None:
            return self._target.window
        return self._window

    @window.setter
    def window(self, window):
        # NOTE(review): validated with ``assert``, so this check is skipped
        # when Python runs with -O.
        assert window is None or isinstance(
            window, WindowBatch
        ), "BatchSceneModel window must be a WindowBatch."
        self._window = window

    @forward
    def __call__(self, model_params=None, model_dims=None, **kwargs):
        """Evaluate the wrapped model once per image in the batch.

        The target's ``CD``, ``crtan`` and shifted ``crpix`` arrays are mapped
        over axis 0 with ``backend.vmap``. The target's ``psf_stack`` is also
        mapped over axis 0 when it is not None. ``model_params`` is mapped
        according to ``model_dims``.

        Parameters
        ----------
        model_params
            Parameter values passed positionally to ``self.model`` after the
            PSF argument.
        model_dims
            ``in_dims`` entry for ``model_params`` in ``backend.vmap``.

        Returns
        -------
        A model image for ``self.window`` whose data holds the vmapped
        per-image evaluations of ``self.model``.
        """
        working_image = self.target.model_image(self.window)
        crtan = self.target.crtan
        # Offset between the batched windows and the wrapped model's window,
        # applied to the reference pixel. This presumably aligns each image
        # with the model's own window; confirm against origin_shifter.
        shift = backend.as_array(
            self.window.origin_shifter(self.model.window),
            dtype=config.DTYPE,
            device=config.DEVICE,
        )
        crpix = self.target.crpix + shift
        CD = self.target.CD
        psf = self.target.psf_stack
        # Only map over the PSF when the targets actually provide one.
        psf_batch = None if psf is None else 0
        working_image._data = backend.vmap(
            lambda *args: self.model(*args)._data,
            in_dims=(0, 0, 0, psf_batch, model_dims),
        )(CD, crtan, crpix, psf, model_params)
        return working_image