Source code for astrophot.models.base

from typing import Optional, Union
from copy import deepcopy

import numpy as np

from caskade import Param as CParam
from ..param import Module, forward, Param
from ..utils.decorators import classproperty
from ..image import Window, ModelImage, ModelImageList
from ..errors import UnrecognizedModel, InvalidWindow
from .. import config
from ..backend_obj import backend, ArrayLike
from . import func

# Names exported by ``from astrophot.models.base import *``: only the Model base class.
__all__ = ("Model",)


######################################################################
class Model(Module):
    """Base class for all AstroPhot models.

    Constructing ``Model(model_type=...)`` or ``Model(filename=...)`` dispatches
    to the subclass returned by ``Model.List_Models()`` whose ``model_type``
    string matches, so this class also acts as a factory for concrete models.

    Class-level configuration is merged across the class hierarchy:

    * ``_model_type`` -- each class contributes one word to ``model_type``.
    * ``_parameter_specs`` -- per-parameter keyword dicts used to build ``Param``
      objects in ``__init__``.
    * ``_options`` -- attribute names that may be overridden through keyword
      arguments to ``__init__``; they are also added to ``saveattrs``.
    """

    _model_type = "model"
    _parameter_specs = {}
    # Softening length used for numerical stability and/or integration
    # stability to avoid discontinuities (near R=0)
    softening = 1e-3  # arcsec
    _options = ("softening",)
    # ``List_Models`` only looks at ``usable`` in a class's *own* namespace, so a
    # subclass must set it explicitly to be reported as usable.
    usable = False

    def __new__(cls, *, filename=None, model_type=None, **kwargs):
        """Select the concrete class to instantiate.

        If ``filename`` is given, the ``"model_type"`` entry of the state
        returned by ``Model.load(filename)`` is used (taking precedence over a
        ``model_type`` argument). Otherwise ``model_type`` is used when given.
        With neither, ``cls`` itself is instantiated.

        Raises:
            UnrecognizedModel: no known model class has the requested type.
        """
        if filename is not None:
            model_type = Model.load(filename)["model_type"]
        if model_type is None:
            return super().__new__(cls)
        for M in Model.List_Models():  # all_subclasses(Model)
            if M.model_type == model_type:
                return super(Model, cls).__new__(M)
        raise UnrecognizedModel(f"Unknown AstroPhot model type: {model_type}")

    def __init__(self, *, name=None, target=None, window=None, mask=None, filename=None, **kwargs):
        """Initialise the model, its options and its parameters.

        Remaining keyword arguments are consumed in this order: class options
        (see ``options``), then parameter values/specs (see
        ``build_parameter_specs``). ``model_type`` is discarded because
        ``__new__`` already used it.

        Raises:
            TypeError: if any keyword argument is left unrecognised.
        """
        # NOTE(review): ``filename`` is only used by ``__new__`` to pick the
        # class; nothing here loads the saved state -- confirm that is intended.
        super().__init__(name=name)
        self.target = target
        # ``target`` is assigned first because the window setter may build a
        # Window relative to it.
        self.window = window
        self.mask = mask

        # Set any user defined options for the model
        options = self.options
        for kwarg in list(kwargs):
            if kwarg in options:
                setattr(self, kwarg, kwargs.pop(kwarg))

        # Create Param objects for this Module
        parameter_specs = self.build_parameter_specs(kwargs, self.parameter_specs)
        for key, spec in parameter_specs.items():
            param = Param(key, **spec, dtype=config.DTYPE, device=config.DEVICE)
            setattr(self, key, param)

        self.saveattrs.update(options)

        kwargs.pop("model_type", None)  # model_type is set by __new__
        if kwargs:
            raise TypeError(
                f"Unrecognized keyword arguments for {self.__class__.__name__}: {', '.join(kwargs.keys())}"
            )

    @classproperty
    def model_type(cls) -> str:
        """Space-separated ``_model_type`` words from the MRO, most derived first."""
        collected = []
        for subcls in cls.mro():
            if subcls is object:
                continue
            mt = subcls.__dict__.get("_model_type", None)
            if mt:
                collected.append(mt)
        return " ".join(collected)

    @classproperty
    def options(cls) -> set:
        """Union of the ``_options`` each class in the MRO defines itself."""
        options = set()
        for subcls in cls.mro():
            if subcls is object:
                continue
            options.update(subcls.__dict__.get("_options", []))
        return options

    @classproperty
    def parameter_specs(cls) -> dict:
        """Collects all parameter specifications from the class hierarchy.

        Classes are visited from the most basic to the most derived (reversed
        MRO). Only the ``_parameter_specs`` each class defines in its *own*
        namespace is merged, so a key is taken from the highest-priority class
        in the MRO that defines it. Using the class ``__dict__`` (as
        ``options`` and ``model_type`` do) prevents a subclass without its own
        specs from re-applying an inherited dict over a higher-priority base
        in multiple-inheritance hierarchies.

        The values are the class-level dicts themselves (not copies);
        ``build_parameter_specs`` deep-copies before modifying them.
        """
        specs = {}
        for subcls in reversed(cls.mro()):
            if subcls is object:
                continue
            specs.update(subcls.__dict__.get("_parameter_specs", {}))
        return specs

    def build_parameter_specs(self, kwargs, parameter_specs) -> dict:
        """Merge user keyword arguments into a copy of ``parameter_specs``.

        For every key of ``kwargs`` that names a parameter, the entry is popped
        from ``kwargs`` (mutating the caller's dict). A dict value updates the
        whole spec; any other value becomes the spec's ``"value"``. When the
        resulting value is a caskade ``Param`` or a callable, the parameter is
        marked ``dynamic=False``.

        Returns:
            A deep copy of ``parameter_specs`` with the user overrides applied.
        """
        parameter_specs = deepcopy(parameter_specs)
        for p in list(kwargs):
            if p not in parameter_specs:
                continue
            if isinstance(kwargs[p], dict):
                parameter_specs[p].update(kwargs.pop(p))
            else:
                parameter_specs[p]["value"] = kwargs.pop(p)
            value = parameter_specs[p].get("value", None)
            # A value linked to another Param or computed by a function is not
            # a free parameter of its own.
            if isinstance(value, CParam) or callable(value):
                parameter_specs[p]["dynamic"] = False
        return parameter_specs

    @forward
    def gaussian_log_likelihood(self) -> ArrayLike:
        """
        Compute the Gaussian log likelihood of the model wrt the target image
        in the appropriate window.

        Returns ``-0.5 * sum(weight * (data - model)**2)`` over unmasked
        pixels (normalisation constants are omitted). ``weight`` is presumably
        the inverse variance of the target -- TODO confirm.
        """
        model = self().flatten("data")
        data = self.target[self.window]
        weight = data.flatten("weight")
        mask = data.flatten("mask")
        data = data.flatten("data")
        nll = 0.5 * backend.sum((data - model) ** 2 * weight * (~mask))
        return -nll

    @forward
    def poisson_log_likelihood(self) -> ArrayLike:
        """
        Compute the Poisson log likelihood of the model wrt the target image
        in the appropriate window.

        Returns ``-sum(model - data * log(model) + lgamma(data + 1))`` over
        unmasked pixels. A ``1e-10`` offset guards the log against zero model
        values.
        """
        model = self().flatten("data")
        data = self.target[self.window]
        mask = data.flatten("mask")
        data = data.flatten("data")
        nll = backend.sum(
            (model - data * backend.log(model + 1e-10) + backend.lgamma(data + 1)) * (~mask)
        )
        return -nll

    def hessian(self, likelihood="gaussian"):
        """Hessian of the chosen log likelihood at the current parameter values.

        Args:
            likelihood: ``"gaussian"`` or ``"poisson"``.

        Raises:
            ValueError: for any other likelihood name.
        """
        if likelihood == "gaussian":
            return backend.hessian(self.gaussian_log_likelihood)(self.get_values())
        elif likelihood == "poisson":
            return backend.hessian(self.poisson_log_likelihood)(self.get_values())
        else:
            raise ValueError(f"Unknown likelihood type: {likelihood}")

    def total_flux(self) -> ArrayLike:
        """Sum of all pixel values of the sampled model image."""
        F = self()
        return backend.sum(F.flatten("data"))

    def total_flux_uncertainty(self) -> ArrayLike:
        """First-order uncertainty of ``total_flux``.

        The gradient of the total flux is obtained by summing the model
        Jacobian over pixels. It is combined with the per-parameter
        uncertainties as ``sqrt(sum((dF_i * sigma_i)**2))``, which ignores
        parameter covariances. Assumes axis 0 of the flattened Jacobian indexes
        pixels and ``build_params_array_uncertainty`` follows the same parameter
        order -- TODO confirm.
        """
        jac = self.jacobian().flatten("data")
        dF = backend.sum(jac, dim=0)  # VJP for sum(total_flux)
        current_uncertainty = self.build_params_array_uncertainty()
        return backend.sqrt(backend.sum((dF * current_uncertainty) ** 2))

    def total_magnitude(self) -> ArrayLike:
        """Compute the total magnitude of the model in the given window."""
        F = self.total_flux()
        return -2.5 * backend.log10(F) + self.target.zeropoint

    def total_magnitude_uncertainty(self) -> ArrayLike:
        """Compute the uncertainty in the total magnitude of the model in the given window.

        Uses the linear propagation ``dm = 2.5 / ln(10) * dF / F``.
        """
        F = self.total_flux()
        dF = self.total_flux_uncertainty()
        return 2.5 * (dF / F) / np.log(10)

    @property
    def window(self) -> Optional[Window]:
        """The window defines a region on the sky in which this model will be
        optimized and typically evaluated. Two models with non-overlapping
        windows are in effect independent of each other. If there is another
        model with a window that spans both of them, then they are tenuously
        connected.

        If not provided, the model will assume a window equal to the target it
        is fitting. Note that in this case the window is not explicitly set to
        the target window, so if the model is moved to another target then the
        fitting window will also change.

        Raises:
            ValueError: if neither a window nor a target is available.
        """
        if self._window is None:
            if self.target is None:
                raise ValueError(
                    "This model has no target or window, these must be provided by the user"
                )
            return self.target.window
        return self._window

    @window.setter
    def window(self, window):
        # Accepts None (follow the target), a Window, or a 2/4-element
        # sequence that is converted into a Window relative to the target.
        if window is None:
            self._window = None
        elif isinstance(window, Window):
            self._window = window
        elif len(window) in (2, 4):
            self._window = Window(window, image=self.target)
        else:
            raise InvalidWindow(f"Unrecognized window format: {str(window)}")

    @classmethod
    def List_Models(cls, usable: Optional[bool] = None, types: bool = False) -> set:
        """Subclasses of ``cls`` (via ``func.all_subclasses``), optionally filtered.

        Args:
            usable: ``None`` returns every subclass. Otherwise, only classes
                whose *own* ``usable`` attribute (default ``False`` when not set
                in the class body) is identical to this value are returned.
            types: return ``model_type`` strings instead of classes.
        """
        result = set()
        for model in func.all_subclasses(cls):
            if not (model.__dict__.get("usable", False) is usable or usable is None):
                continue
            result.add(model.model_type if types else model)
        return result

    @forward
    def radius_metric(self, x, y):
        """Softened radius ``sqrt(x**2 + y**2 + softening**2)`` (never zero)."""
        return backend.sqrt(x**2 + y**2 + self.softening**2)

    @forward
    def angular_metric(self, x, y):
        """Polar angle ``arctan2(y, x)`` with ``x`` pushed away from zero by ``softening``.

        ``x`` is shifted by ``-softening`` when negative and ``+softening``
        otherwise (including ``x == 0``), avoiding the undefined angle at the origin.
        """
        return backend.arctan2(y, backend.where(x < 0, x - self.softening, x + self.softening))

    def to(self, dtype=None, device=None):
        """Forward to ``Module.to`` with defaults from ``config``.

        ``dtype`` and ``device`` fall back to ``config.DTYPE`` and
        ``config.DEVICE``. The parent's return value is passed through.
        """
        if dtype is None:
            dtype = config.DTYPE
        if device is None:
            device = config.DEVICE
        return super().to(dtype=dtype, device=device)

    @forward
    def __call__(
        self,
        *args,
        **kwargs,
    ) -> Union[ModelImage, ModelImageList]:
        """Sample the model; all arguments are forwarded to ``self.sample``."""
        return self.sample(*args, **kwargs)