Source code for astrophot.fit.base

from typing import Sequence, Optional

import numpy as np
import torch
from scipy.optimize import minimize
from scipy.special import gammainc

from .. import AP_config


__all__ = ["BaseOptimizer"]


class BaseOptimizer(object):
    """Base optimizer object that other optimizers inherit from.

    Ensures a consistent constructor signature and a shared set of
    bookkeeping attributes (state, iteration counter, histories) for the
    optimizer subclasses. ``fit`` and ``step`` are placeholders that raise
    ``NotImplementedError`` and must be overridden.

    Parameters:
        model: an AstroPhot_Model object that will have its (unlocked)
            parameters optimized [AstroPhot_Model]
        initial_state: optional initialization for the parameters as a 1D
            tensor [tensor]
        relative_tolerance: tolerance for counting success steps as:
            0 < (Chi2^2 - Chi1^2)/Chi1^2 < tol [float]
        fit_window: optional window restricting the fit; it is combined
            with the model's own window using ``&`` [Window]
        max_iter: (keyword only, via ``**kwargs``) maximum allowed number
            of iterations [int]
    """

    def __init__(
        self,
        model: "AstroPhot_Model",
        initial_state: Optional[Sequence] = None,
        relative_tolerance: float = 1e-3,
        fit_window: Optional["Window"] = None,
        **kwargs,
    ) -> None:
        """Initializes a new instance of the class.

        Args:
            model (object): An object representing the model.
            initial_state (Optional[Sequence]): The initial state of the
                model; anything ``torch.as_tensor`` accepts. If `None`,
                ``model.initialize()`` is called and the model's
                ``parameters.vector_representation()`` is used.
            relative_tolerance (float): The relative tolerance for the
                optimization.
            fit_window (Optional[Window]): Window to fit in. If `None`, the
                model's window is used; otherwise ``fit_window &
                model.window`` is used.
            **kwargs (dict): Additional keyword arguments. The keys read
                here are ``verbose`` (default 0), ``max_iter`` (default 100
                times the number of entries in the initial state) and
                ``save_steps`` (default `None`).

        Attributes:
            model (object): An object representing the model.
            verbose (int): The verbosity level.
            fit_window (Window): The window used for the fit.
            current_state (Tensor): The current state of the model.
            max_iter (int): The maximum number of iterations.
            iteration (int): The current iteration number.
            save_steps (Optional[str]): Save intermediate results to this
                path.
            relative_tolerance (float): The relative tolerance for the
                optimization.
            lambda_history (List[ndarray]): A list of the optimization
                steps.
            loss_history (List[float]): A list of the optimization losses.
            message (str): An informational message.
        """
        self.model = model
        self.verbose = kwargs.get("verbose", 0)

        if fit_window is None:
            self.fit_window = self.model.window
        else:
            self.fit_window = fit_window & self.model.window

        if initial_state is None:
            # No starting point supplied: let the model pick its own.
            self.model.initialize()
            initial_state = self.model.parameters.vector_representation()

        # Single conversion onto the configured dtype/device, whichever
        # branch produced the initial state.
        self.current_state = torch.as_tensor(
            initial_state, dtype=AP_config.ap_dtype, device=AP_config.ap_device
        )
        if self.verbose > 1:
            AP_config.ap_logger.info(f"initial state: {self.current_state}")

        self.max_iter = kwargs.get("max_iter", 100 * len(self.current_state))
        self.iteration = 0
        self.save_steps = kwargs.get("save_steps", None)

        self.relative_tolerance = relative_tolerance
        self.lambda_history = []
        self.loss_history = []
        self.message = ""

    def fit(self) -> "BaseOptimizer":
        """Run the optimization (placeholder in the base class).

        Raises:
            NotImplementedError: Error is raised if this method is not
                implemented in a subclass of BaseOptimizer.
        """
        raise NotImplementedError("Please use a subclass of BaseOptimizer for optimization")

    def step(self, current_state: Optional[torch.Tensor] = None) -> None:
        """Take a single optimization step (placeholder in the base class).

        Args:
            current_state (torch.Tensor, optional): Current state of the
                model parameters. Defaults to None.

        Raises:
            NotImplementedError: Error is raised if this method is not
                implemented in a subclass of BaseOptimizer.
        """
        raise NotImplementedError("Please use a subclass of BaseOptimizer for optimization")

    def chi2min(self) -> float:
        """Returns the minimum value of chi^2 loss in the loss history.

        NaN entries are ignored (``np.nanmin``).

        Returns:
            float: Minimum value of chi^2 loss.
        """
        return np.nanmin(self.loss_history)

    def res(self) -> np.ndarray:
        """Returns the entry of ``lambda_history`` with the lowest finite loss.

        Only finite entries of ``loss_history`` are considered. If there are
        none, a warning is logged and the current state is returned instead
        (detached, moved to CPU and converted to a numpy array).

        Returns:
            ndarray: The ``lambda_history`` entry at which the minimum
            finite chi^2 loss was achieved, or the current state when no
            finite loss has been recorded.
        """
        finite = np.isfinite(self.loss_history)
        if np.sum(finite) == 0:
            AP_config.ap_logger.warning(
                "Getting optimizer res with no real loss history, using current state"
            )
            return self.current_state.detach().cpu().numpy()

        finite_losses = np.array(self.loss_history)[finite]
        finite_states = np.array(self.lambda_history)[finite]
        return finite_states[np.argmin(finite_losses)]

    def res_loss(self):
        """Returns the smallest finite value in the loss history.

        Non-finite entries (NaN, +/-inf) are masked out before taking the
        minimum.
        """
        finite = np.isfinite(self.loss_history)
        return np.min(np.array(self.loss_history)[finite])

    @staticmethod
    def chi2contour(n_params: int, confidence: float = 0.682689492137) -> float:
        """Calculates the chi^2 contour for the given number of parameters.

        Numerically solves ``gammainc(n_params / 2, x / 2) == confidence``
        for ``x`` by minimizing the squared residual, trying several scipy
        minimizers in turn and returning the first successful result.

        Args:
            n_params (int): The number of parameters.
            confidence (float, optional): The confidence interval (default
                is 0.682689492137).

        Returns:
            float: The calculated chi^2 contour value.

        Raises:
            RuntimeError: If unable to compute the Chi^2 contour for the
                given number of parameters.
        """

        def _f(x: float, nu: int) -> float:
            """Squared distance between the chi^2 CDF at x and the target."""
            return (gammainc(nu / 2, x / 2) - confidence) ** 2

        for method in ["L-BFGS-B", "Powell", "Nelder-Mead"]:
            res = minimize(_f, x0=n_params, args=(n_params,), method=method, tol=1e-8)
            if res.success:
                return res.x[0]

        # The original message interpolated an undefined name here, which
        # turned the documented RuntimeError into a NameError.
        raise RuntimeError(f"Unable to compute Chi^2 contour for ndf: {n_params}")