# Traditional gradient descent with Adam
from time import time
from typing import Sequence
from caskade import ValidContext
import torch
import numpy as np
from .base import BaseOptimizer
from .. import config
from ..backend_obj import backend, ArrayLike
from ..models import Model
from ..errors import OptimizeStopFail, OptimizeStopSuccess
from . import func
from ..utils.decorators import combine_docstrings
__all__ = ["Grad"]
[docs]
@combine_docstrings
class Grad(BaseOptimizer):
"""A gradient descent optimization wrapper for AstroPhot Model objects.
The default method is "NAdam", a variant of the Adam optimization algorithm.
This optimizer uses a combination of gradient descent and Nesterov momentum for faster convergence.
The optimizer is instantiated with a set of initial parameters and optimization options provided by the user.
The `fit` method performs the optimization, taking a series of gradient steps until a stopping criteria is met.
**Args:**
- `likelihood` (str, optional): The likelihood function to use for the optimization. Defaults to "gaussian".
- `method` (str, optional): the optimization method to use for the update step. Defaults to "NAdam".
- `optim_kwargs` (dict, optional): a dictionary of keyword arguments to pass to the pytorch optimizer.
- `patience` (int, optional): number of steps with no improvement before stopping the optimization. Defaults to 10.
- `report_freq` (int, optional): frequency of reporting the optimization progress. Defaults to 10 steps.
"""
def __init__(
self,
model: Model,
initial_state: Sequence = None,
likelihood="gaussian",
method="NAdam",
optim_kwargs={},
patience: int = 10,
report_freq=10,
**kwargs,
) -> None:
super().__init__(model, initial_state, **kwargs)
self.likelihood = likelihood
# set parameters from the user
self.patience = patience
self.method = method
self.optim_kwargs = optim_kwargs
self.report_freq = report_freq
# Default learning rate if none given. Equal to 1 / sqrt(parames)
if "lr" not in self.optim_kwargs:
self.optim_kwargs["lr"] = 0.1 / (len(self.current_state) ** (0.5))
# Instantiates the appropriate pytorch optimizer with the initial state and user provided kwargs
self.current_state.requires_grad = True
self.optimizer = getattr(torch.optim, self.method)(
(self.current_state,), **self.optim_kwargs
)
[docs]
def density(self, state: torch.Tensor) -> torch.Tensor:
"""
Returns the density of the model at the given state vector. This is used
to calculate the likelihood of the model at the given state. Based on
``self.likelihood``, will be either the Gaussian or Poisson negative log
likelihood.
"""
if self.likelihood == "gaussian":
return -self.model.gaussian_log_likelihood(state)
elif self.likelihood == "poisson":
return -self.model.poisson_log_likelihood(state)
else:
raise ValueError(f"Unknown likelihood type: {self.likelihood}")
[docs]
def step(self) -> None:
"""Take a single gradient step.
Computes the loss function of the model,
computes the gradient of the parameters using automatic differentiation,
and takes a step with the PyTorch optimizer.
"""
self.iteration += 1
self.optimizer.zero_grad()
self.current_state.requires_grad = True
loss = self.density(self.current_state)
loss.backward()
self.loss_history.append(backend.to_numpy(loss))
self.lambda_history.append(np.copy(backend.to_numpy(self.current_state)))
if (
self.iteration % int(self.max_iter / self.report_freq) == 0
) or self.iteration == self.max_iter:
if self.verbose > 0:
config.logger.info(f"iter: {self.iteration}, posterior density: {loss.item():.6e}")
if self.verbose > 1:
config.logger.info(f"gradient: {self.current_state.grad}")
self.optimizer.step()
[docs]
def fit(self) -> BaseOptimizer:
"""
Perform an iterative fit of the model parameters using the specified optimizer.
The fit procedure continues until a stopping criteria is met,
such as the maximum number of iterations being reached,
or no improvement being made after a specified number of iterations.
"""
start_fit = time()
try:
while True:
self.step()
if self.iteration >= self.max_iter:
self.message = self.message + " fail max iteration reached"
break
if (
self.patience is not None
and (len(self.loss_history) - np.argmin(self.loss_history)) > self.patience
):
self.message = self.message + " fail no improvement"
break
L = np.sort(self.loss_history)
if len(L) >= 5 and 0 < (L[4] - L[0]) / L[0] < self.relative_tolerance:
self.message = self.message + " success"
break
except KeyboardInterrupt:
self.message = self.message + " fail interrupted"
# Set the model parameters to the best values from the fit and clear any previous model sampling
self.model.set_values(torch.tensor(self.res(), dtype=config.DTYPE, device=config.DEVICE))
if self.verbose > 1:
config.logger.info(
f"Grad Fitting complete in {time() - start_fit} sec with message: {self.message}"
)
return self
[docs]
class Slalom(BaseOptimizer):
"""Slalom optimizer for Model objects.
Slalom is a gradient descent optimization algorithm that uses a few
evaluations along the direction of the gradient to find the optimal step
size. This is done by assuming that the posterior density is a parabola and
then finding the minimum.
The optimizer quickly finds the minimum of the posterior density along the
gradient direction, then updates the gradient at the new position and
repeats. This continues until it reaches a set of 5 steps which collectively
improve the posterior density by an amount smaller than the
`relative_tolerance` threshold, indicating that convergence has been
achieved. Note that this convergence criteria is not a guarantee, simply a
heuristic. The default tolerance was such that the optimizer will
substantially improve from the starting point, and do so quickly, but may
not reach all the way to the minimum of the posterior density. Like other
gradient descent algorithms, Slalom slows down considerably when trying to
achieve very high precision.
**Args:**
- `S` (float, optional): The initial step size for the Slalom optimizer. Defaults to 1e-4.
- `likelihood` (str, optional): The likelihood function to use for the optimization. Defaults to "gaussian".
- `report_freq` (int, optional): Frequency of reporting the optimization progress. Defaults to 10 steps.
- `relative_tolerance` (float, optional): The relative tolerance for convergence. Defaults to 1e-4.
- `momentum` (float, optional): The momentum factor for the Slalom optimizer. Defaults to 0.5.
- `max_iter` (int, optional): The maximum number of iterations for the optimizer. Defaults to 1000.
"""
def __init__(
self,
model: Model,
initial_state: Sequence = None,
S=1e-4,
likelihood: str = "gaussian",
report_freq: int = 10,
relative_tolerance: float = 1e-4,
momentum: float = 0.5,
max_iter: int = 1000,
**kwargs,
) -> None:
"""Initialize the Slalom optimizer."""
super().__init__(
model, initial_state, relative_tolerance=relative_tolerance, max_iter=max_iter, **kwargs
)
self.likelihood = likelihood
self.S = S
self.report_freq = report_freq
self.momentum = momentum
[docs]
def density(self, state: ArrayLike) -> ArrayLike:
"""Calculate the density of the model at the given state. Based on
``self.likelihood``, will be either the Gaussian or Poisson negative log
likelihood."""
if self.likelihood == "gaussian":
return -self.model.gaussian_log_likelihood(state)
elif self.likelihood == "poisson":
return -self.model.poisson_log_likelihood(state)
else:
raise ValueError(f"Unknown likelihood type: {self.likelihood}")
[docs]
def fit(self) -> BaseOptimizer:
"""Perform the Slalom optimization."""
grad_func = backend.grad(self.density)
momentum = backend.zeros_like(self.current_state)
self.S_history = [self.S]
self.loss_history = [self.density(self.current_state).item()]
self.lambda_history = [backend.to_numpy(self.current_state)]
self.start_fit = time()
for i in range(self.max_iter):
try:
# Perform the Slalom step
vstate = self.model.to_valid(self.current_state)
with ValidContext(self.model):
self.S, loss, grad = func.slalom_step(
self.density, grad_func, vstate, m=momentum, S=self.S
)
self.current_state = self.model.from_valid(
vstate - self.S * (grad + momentum) / backend.linalg.norm(grad + momentum)
)
momentum = self.momentum * (momentum + grad)
except OptimizeStopSuccess as e:
self.message = self.message + str(e)
break
except OptimizeStopFail as e:
if backend.allclose(momentum, backend.zeros_like(momentum)):
self.message = self.message + str(e)
break
momentum = backend.zeros_like(self.current_state)
continue
# Log the loss
self.S_history.append(self.S)
self.loss_history.append(loss)
self.lambda_history.append(backend.to_numpy(self.current_state))
if self.verbose > 0 and (i % int(self.report_freq) == 0 or i == self.max_iter - 1):
config.logger.info(
f"iter: {i}, step size: {self.S:.6e}, posterior density: {loss:.6e}"
)
if len(self.loss_history) >= 5:
relative_loss = (self.loss_history[-5] - self.loss_history[-1]) / self.loss_history[
-1
]
if relative_loss < self.relative_tolerance:
self.message = self.message + " success"
break
else:
self.message = self.message + " fail. max iteration reached"
# Set the model parameters to the best values from the fit
self.model.set_values(
backend.as_array(self.res(), dtype=config.DTYPE, device=config.DEVICE)
)
if self.verbose > 0:
config.logger.info(
f"Slalom Fitting complete in {time() - self.start_fit} sec with message: {self.message}"
)
return self