# Source code for astrophot.fit.gradient

# Traditional gradient descent with Adam
from time import time
from typing import Sequence
from caskade import ValidContext

from .base import BaseOptimizer
from .. import config
from ..backend_obj import backend, ArrayLike
from ..models import Model
from ..errors import OptimizeStopFail, OptimizeStopSuccess
from . import func

# Public optimizers exported by this module; Slalom is defined below and must
# be listed here for ``from astrophot.fit.gradient import *`` to expose it.
__all__ = ["Grad", "Slalom"]


class Slalom(BaseOptimizer):
    """Slalom optimizer for Model objects.

    Slalom is a gradient descent optimization algorithm that uses a few
    evaluations along the direction of the gradient to find the optimal step
    size. This is done by assuming that the posterior density is a parabola
    and then finding the minimum. The optimizer quickly finds the minimum of
    the posterior density along the gradient direction, then updates the
    gradient at the new position and repeats. This continues until the most
    recent window of five recorded loss values (spanning four steps)
    collectively improves the posterior density by a relative amount smaller
    than the ``relative_tolerance`` threshold, indicating that convergence has
    been achieved. Note that this convergence criterion is not a guarantee,
    simply a heuristic. The default tolerance was chosen such that the
    optimizer will substantially improve from the starting point, and do so
    quickly, but may not reach all the way to the minimum of the posterior
    density. Like other gradient descent algorithms, Slalom slows down
    considerably when trying to achieve very high precision.

    :param model: The model whose parameters are optimized.
    :type model: Model
    :param initial_state: Starting parameter values, forwarded to
        ``BaseOptimizer``. Defaults to None.
    :type initial_state: Sequence, optional
    :param S: The initial step size for the Slalom optimizer. Defaults to 1e-4.
    :type S: float, optional
    :param likelihood: The likelihood function to use for the optimization,
        either ``"gaussian"`` or ``"poisson"``. Defaults to "gaussian".
    :type likelihood: str, optional
    :param report_freq: Frequency of reporting the optimization progress.
        Defaults to 10 steps.
    :type report_freq: int, optional
    :param relative_tolerance: The relative tolerance for convergence.
        Defaults to 1e-4.
    :type relative_tolerance: float, optional
    :param momentum: The momentum factor for the Slalom optimizer; must lie in
        the range ``[0, 1)``. Defaults to 0.9.
    :type momentum: float, optional
    :param max_iter: The maximum number of iterations for the optimizer.
        Defaults to 1000.
    :type max_iter: int, optional
    :raises ValueError: If ``momentum`` is outside the range ``[0, 1)``.
    """

    def __init__(
        self,
        model: Model,
        initial_state: Sequence = None,
        S=1e-4,
        likelihood: str = "gaussian",
        report_freq: int = 10,
        relative_tolerance: float = 1e-4,
        momentum: float = 0.9,
        max_iter: int = 1000,
        **kwargs,
    ) -> None:
        """Initialize the Slalom optimizer."""
        # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
        if not 0 <= momentum < 1:
            raise ValueError("Momentum must be in the range [0, 1).")
        super().__init__(
            model, initial_state, relative_tolerance=relative_tolerance, max_iter=max_iter, **kwargs
        )
        self.likelihood = likelihood
        self.S = S
        self.report_freq = report_freq
        self.momentum = momentum

    def density(self, state: ArrayLike) -> ArrayLike:
        """Calculate the density of the model at the given state.

        Based on ``self.likelihood``, this is either the Gaussian or the
        Poisson negative log likelihood (the model's log likelihood, negated
        so that lower values are better).

        :raises ValueError: If ``self.likelihood`` is not a supported name.
        """
        if self.likelihood == "gaussian":
            return -self.model.gaussian_log_likelihood(state)
        elif self.likelihood == "poisson":
            return -self.model.poisson_log_likelihood(state)
        else:
            raise ValueError(f"Unknown likelihood type: {self.likelihood}")

    @staticmethod
    def _relative_improvement(old_loss, new_loss):
        """Return the relative decrease ``(old_loss - new_loss) / |new_loss|``.

        Using the absolute value in the denominator keeps the sign of the
        result tied to the direction of change, so a genuine improvement is
        positive even when the loss itself is negative. For non-negative
        losses this equals dividing by ``new_loss`` directly.
        """
        return (old_loss - new_loss) / abs(new_loss)

    def fit(self) -> BaseOptimizer:
        """Perform the Slalom optimization.

        Each iteration maps the current state into the model's valid parameter
        space, lets ``func.slalom_step`` choose a new step size and return the
        loss and gradient, then moves a distance ``S`` along the normalized
        ``grad + momentum`` direction and maps back with ``from_valid``. The
        momentum buffer is then updated as an exponential moving average of
        gradients weighted by ``self.momentum``.

        Step sizes, losses and states are recorded in ``S_history``,
        ``loss_history`` and ``lambda_history``. On exit the model values are
        set from ``self.res()`` and ``self.message`` records why the loop
        stopped.

        :returns: This optimizer instance.
        :rtype: BaseOptimizer
        """
        grad_func = backend.grad(self.density)
        momentum = backend.zeros_like(self.current_state)

        self.S_history = [self.S]
        self.loss_history = [self.density(self.current_state).item()]
        self.lambda_history = [backend.to_numpy(self.current_state)]
        self.start_fit = time()

        for i in range(self.max_iter):
            try:
                # Perform the Slalom step in the model's valid parameter space
                vstate = self.model.to_valid(self.current_state)
                with ValidContext(self.model):
                    self.S, loss, grad = func.slalom_step(
                        self.density, grad_func, vstate, m=momentum, S=self.S
                    )
                direction = grad + momentum
                self.current_state = self.model.from_valid(
                    vstate - self.S * direction / backend.linalg.norm(direction)
                )
                momentum = self.momentum * momentum + (1 - self.momentum) * grad
            except OptimizeStopSuccess as e:
                self.message = self.message + str(e)
                break
            except OptimizeStopFail as e:
                # A failure with zero momentum cannot be retried any differently,
                # so stop; otherwise discard the momentum and try again.
                if backend.allclose(momentum, backend.zeros_like(momentum)):
                    self.message = self.message + str(e)
                    break
                momentum = backend.zeros_like(self.current_state)
                continue

            # Log the loss
            self.S_history.append(self.S)
            self.loss_history.append(loss)
            self.lambda_history.append(backend.to_numpy(self.current_state))

            if self.verbose > 0 and (i % int(self.report_freq) == 0 or i == self.max_iter - 1):
                config.logger.info(
                    f"iter: {i}, step size: {self.S:.6e}, posterior density: {loss:.6e}"
                )

            # Converged when the last five recorded losses improve too little.
            if len(self.loss_history) >= 5:
                relative_loss = self._relative_improvement(
                    self.loss_history[-5], self.loss_history[-1]
                )
                if relative_loss < self.relative_tolerance:
                    self.message = self.message + " success"
                    break
        else:
            self.message = self.message + " fail. max iteration reached"

        # Set the model parameters to the best values from the fit
        self.model.set_values(
            backend.as_array(self.res(), dtype=config.DTYPE, device=config.DEVICE)
        )
        if self.verbose > 0:
            config.logger.info(
                f"Slalom Fitting complete in {time() - self.start_fit} sec with message: {self.message}"
            )

        return self