Source code for astrophot.fit.base
from typing import Sequence, Optional
import numpy as np
from .. import config
from ..backend_obj import backend, ArrayLike
from ..models import Model
# Public API of this module: only the optimizer base class is exported.
__all__ = ("BaseOptimizer",)
[docs]
class BaseOptimizer:
    """
    Base optimizer object that other optimizers inherit from. Ensures a consistent signature for the classes.

    :param model: a Model object that will have its (unlocked) parameters optimized [Model]
    :param initial_state: optional initialization for the parameters as a 1D array; defaults to ``model.get_values()`` [Array]
    :param relative_tolerance: tolerance for counting success steps as: $0 < (\\chi_2^2 - \\chi_1^2)/\\chi_1^2 < \\text{tol}$ [float]
    :param verbose: verbosity level for the optimizer [int]
    :param max_iter: maximum allowed number of iterations; defaults to 100 times the length of the initial state [int]
    :param save_steps: optional string for path to save the model at each step (fitter dependent), e.g. "model_step_{step}.hdf5" [str]
    :param fit_valid: whether to fit while forcing parameters into valid range, or allow any value for each parameter. Default True [bool]
    """

    def __init__(
        self,
        model: "Model",
        initial_state: Optional[Sequence] = None,
        relative_tolerance: float = 1e-3,
        verbose: int = 1,
        max_iter: Optional[int] = None,
        save_steps: Optional[str] = None,
        fit_valid: bool = True,
    ) -> None:
        self.model = model
        self.verbose = verbose
        if initial_state is None:
            # Start from the model's current parameter values.
            self.current_state = model.get_values()
        else:
            # Convert user-supplied values to the configured backend dtype/device.
            self.current_state = backend.as_array(
                initial_state, dtype=config.DTYPE, device=config.DEVICE
            )
        # Default iteration budget scales with the size of the parameter state.
        self.max_iter = max_iter if max_iter is not None else 100 * len(self.current_state)
        self.iteration = 0
        self.save_steps = save_steps
        self.fit_valid = fit_valid
        self.relative_tolerance = relative_tolerance
        # Parallel histories: res() pairs lambda_history[i] with loss_history[i],
        # so the two lists must stay the same length. Presumably appended to by
        # subclasses while fitting.
        self.lambda_history = []
        self.loss_history = []
        self.message = ""

    def fit(self) -> "BaseOptimizer":
        """Run the optimization. Must be implemented by subclasses.

        :raises NotImplementedError: always, on the base class.
        """
        raise NotImplementedError("Please use a subclass of BaseOptimizer for optimization")

    def step(self, current_state: Optional["ArrayLike"] = None) -> None:
        """Take a single optimization step. Must be implemented by subclasses.

        :param current_state: optional parameter state; its meaning is defined by the subclass.
        :raises NotImplementedError: always, on the base class.
        """
        raise NotImplementedError("Please use a subclass of BaseOptimizer for optimization")

    def chi2min(self) -> float:
        """
        Returns the minimum value of chi^2 loss in the loss history.

        NaN entries are ignored (``np.nanmin``); infinite entries are not filtered.
        """
        return np.nanmin(self.loss_history)

    def res(self) -> np.ndarray:
        """Returns the value of lambda (state parameters) at which minimum loss was achieved.

        Only finite losses are considered. If the loss history holds no finite
        value, a warning is logged and the current state is returned instead.
        """
        finite = np.isfinite(self.loss_history)
        if not np.any(finite):
            config.logger.warning(
                "Getting optimizer res with no real loss history, using current state"
            )
            return backend.to_numpy(self.current_state)
        finite_losses = np.array(self.loss_history)[finite]
        return np.array(self.lambda_history)[finite][np.argmin(finite_losses)]

    def res_loss(self) -> float:
        """Returns the minimum finite value from the loss history.

        NaN and infinite entries are ignored, matching the selection used by res().

        :raises ValueError: if the loss history contains no finite values.
        """
        losses = np.asarray(self.loss_history)
        finite = np.isfinite(losses)
        if not np.any(finite):
            # Previously surfaced as numpy's opaque "zero-size array" ValueError.
            raise ValueError("No finite values in optimizer loss history")
        return np.min(losses[finite])