# Source code for astrophot.fit.iterative

# Apply a different optimizer iteratively
from typing import Dict, Any, Union, Sequence, Literal
from time import time
from functools import partial

from caskade import ValidContext
import numpy as np
import torch

from .base import BaseOptimizer
from ..models import Model
from .lm import LM
from .. import config
from ..backend_obj import backend
from ..errors import OptimizeStopSuccess, OptimizeStopFail
from . import func

# Public API of this module. Both optimizer classes defined in this file are
# exported so that ``from ... import *`` exposes every documented optimizer.
__all__ = [
    "Iter",
    "IterParam",
    # "Iter_LM"
]


class Iter(BaseOptimizer):
    """Optimizer wrapper that performs optimization iteratively.

    This optimizer applies the LM optimizer to a group model iteratively one
    model at a time. It can be used for complex fits or when the number of
    models to fit is too large to fit in memory. Note that it will iterate
    through the group model, but if models within the group are themselves
    group models, then they will be optimized as a whole. This gives some
    flexibility to structure the models in a useful way.

    If not given, the `lm_kwargs` will be set to a relative tolerance of 1e-3
    and a maximum of 15 iterations. This is to allow for faster convergence,
    it is not worthwhile for a single model to spend lots of time optimizing
    when its neighbors haven't converged.

    **Args:**
    -  `max_iter`: Maximum number of iterations, defaults to 100.
    -  `lm_kwargs`: Keyword arguments to pass to `LM` optimizer. Defaults to
       `{"verbose": 0}`. The mapping is copied, so the caller's dictionary is
       never modified.
    """

    def __init__(
        self,
        model: Model,
        initial_state: Union[np.ndarray, None] = None,
        max_iter: int = 100,
        lm_kwargs: Union[Dict[str, Any], None] = None,
        **kwargs: Dict[str, Any],
    ):
        super().__init__(model, initial_state, max_iter=max_iter, **kwargs)

        # NOTE(review): this overwrites whatever the base class stored for
        # ``current_state`` (possibly derived from ``initial_state``) -- confirm
        # that ignoring ``initial_state`` here is intended.
        self.current_state = model.get_values()

        # Work on a private copy: the settings are edited below, and neither a
        # shared default nor the caller's own dictionary may be mutated.
        self.lm_kwargs = {"verbose": 0} if lm_kwargs is None else dict(lm_kwargs)
        if "relative_tolerance" not in self.lm_kwargs:
            # Lower tolerance since it's not worth fine tuning a model when its
            # neighbors will be shifting soon anyway
            self.lm_kwargs["relative_tolerance"] = 1e-3
            # Only a default: an explicit ``max_iter`` from the caller is kept.
            self.lm_kwargs.setdefault("max_iter", 15)

        # Degrees of freedom: # pixels - # parameters - # masked pixels
        target = self.model.target[self.model.window]
        self.ndf = target.flatten("data").shape[0] - len(self.current_state)
        self.ndf -= backend.sum(target.flatten("mask")).item()

    def sub_step(self, model: Model, update_uncertainty=False):
        """
        Perform optimization for a single model.

        The contribution of ``model`` is removed from the running full image
        ``self.Y``; what remains is subtracted from the model's target so that
        ``LM`` fits ``model`` against the residual only. Afterwards the
        re-evaluated model is added back into ``self.Y`` and the original
        target is restored.

        Args:
            model: The sub-model to optimize.
            update_uncertainty: Forwarded to ``LM.fit``.
        """
        self.Y -= model()
        initial_values = model.target.copy()
        model.target = model.target - self.Y
        try:
            res = LM(model, **self.lm_kwargs).fit(update_uncertainty=update_uncertainty)
            self.Y += model()
        finally:
            # Always hand the real target back, even if the inner fit raised
            # (e.g. KeyboardInterrupt, which ``fit`` catches); otherwise the
            # model would be left pointing at the residual image.
            model.target = initial_values
        if self.verbose > 1:
            config.logger.info(res.message)

    def step(self):
        """
        Perform a single iteration of optimization.

        Every sub-model of the group is fit once via ``sub_step``; then the
        reduced chi^2 of the full model is recomputed, appended to the
        histories, and the convergence counter ``_count_finish`` is updated.
        """
        if self.verbose > 0:
            config.logger.info("--------iter-------")

        # Fit each model individually
        for model in self.model.models:
            if self.verbose > 0:
                config.logger.info(model.name)
            self.sub_step(model)

        # Update the current state
        self.current_state = self.model.get_values()

        # Update the loss value
        with torch.no_grad():
            if self.verbose > 0:
                config.logger.info("Update Chi^2 with new parameters")
            self.Y = self.model(params=self.current_state)
            target = self.model.target[self.model.window]
            D = target.flatten("data")
            V = target.flatten("variance")
            M = target.flatten("mask")
            # Only unmasked pixels contribute to the loss
            loss = backend.sum((((D - self.Y.flatten("data")) ** 2) / V)[~M]) / self.ndf
        if self.verbose > 0:
            config.logger.info(f"Loss: {loss.item()}")
        self.lambda_history.append(np.copy(backend.to_numpy(self.current_state)))
        self.loss_history.append(loss.item())

        # Test for convergence: the relative improvement of the last step must
        # lie in a narrow band around zero (tiny increases are tolerated).
        if self.iteration >= 2 and (
            (-self.relative_tolerance * 1e-3)
            < ((self.loss_history[-2] - self.loss_history[-1]) / self.loss_history[-1])
            < (self.relative_tolerance / 10)
        ):
            self._count_finish += 1
        else:
            self._count_finish = 0

        self.iteration += 1

    def fit(self) -> BaseOptimizer:
        """
        Perform the iterative fitting process until convergence or maximum
        iterations reached.

        Convergence requires two consecutive iterations that satisfy the
        tolerance test in ``step``. A ``KeyboardInterrupt`` stops the loop
        gracefully. In all cases the model is updated with ``self.res()``
        before returning.

        Returns:
            This optimizer instance, with ``message`` describing the exit.
        """
        self.iteration = 0
        self.Y = self.model(params=self.current_state)
        start_fit = time()

        try:
            while True:
                self.step()
                if self.iteration > 2 and self._count_finish >= 2:
                    self.message = self.message + "success"
                    break
                elif self.iteration >= self.max_iter:
                    self.message = self.message + f"fail max iterations reached: {self.iteration}"
                    break
        except KeyboardInterrupt:
            self.message = self.message + "fail interrupted"

        self.model.set_values(
            backend.as_array(self.res(), dtype=config.DTYPE, device=config.DEVICE)
        )
        if self.verbose > 1:
            config.logger.info(
                f"Iter Fitting complete in {time() - start_fit} sec with message: {self.message}"
            )

        return self
class IterParam(BaseOptimizer):
    """Optimization wrapper that call LM optimizer on subsets of variables.

    IterParam takes the full set of parameters for a model and breaks
    them down into chunks as specified by the user. It then calls
    Levenberg-Marquardt optimization on the subset of parameters, and
    iterates through all subsets until every parameter has been
    optimized. It cycles through these chunks until convergence. This
    method is very powerful in situations where the full optimization
    problem cannot fit in memory, or where the optimization problem is
    too complex to tackle as a single large problem. In full LM
    optimization a single problematic parameter can ripple into issues
    with every other parameter, so breaking the problem down can
    sometimes make an otherwise intractable problem easier. For small
    problems with only a few models, it is likely better to optimize
    the full problem with LM as, when it works, LM is faster than the
    IterParam method.

    Args:
      chunks (Union[int, tuple]): Specify how to break down the model
        parameters. If an integer, at each iteration the algorithm will break
        the parameters into groups of that size. If a tuple, should be a tuple
        of arrays of length num_dimensions which act as selectors for the
        parameters to fit (1 to include, 0 to exclude). Default: 50
      chunk_order (str): How to iterate through the chunks. Should be one of:
        random, sequential. Default: sequential
    """

    def __init__(
        self,
        model: Model,
        initial_state: Union[Sequence, None] = None,
        chunks: Union[int, tuple] = 50,
        chunk_order: Literal["random", "sequential"] = "sequential",
        max_iter: int = 100,
        relative_tolerance: float = 1e-5,
        Lup=11.0,
        Ldn=9.0,
        L0=1.0,
        max_step_iter: int = 10,
        ndf=None,
        W=None,
        likelihood="gaussian",
        **kwargs,
    ):
        super().__init__(
            model,
            initial_state,
            max_iter=max_iter,
            relative_tolerance=relative_tolerance,
            **kwargs,
        )
        # Maximum number of iterations of the algorithm
        self.max_iter = max_iter
        # Maximum number of steps while searching for chi^2 improvement on a
        # single jacobian evaluation
        self.max_step_iter = max_step_iter
        # Damping parameter L and the factors passed to ``func.lm_step`` for
        # adjusting it
        self.Lup = Lup
        self.Ldn = Ldn
        self.L = L0
        self.likelihood = likelihood
        if self.likelihood not in ["gaussian", "poisson"]:
            raise ValueError(f"Unsupported likelihood: {self.likelihood}")

        self.chunks = self.make_chunks(chunks)
        self.chunk_order = chunk_order

        # mask: after inversion, True marks pixels that take part in the fit
        mask = self.model.target[self.model.window].flatten("mask")
        self.mask = ~mask
        if backend.sum(self.mask).item() == 0:
            raise OptimizeStopSuccess("No data to fit. All pixels are masked")

        # Initialize optimizer attributes
        self.Y = self.model.target[self.model.window].flatten("data")[self.mask]

        # 1 / (sigma^2)
        if W is not None:
            self.W = backend.as_array(W, dtype=config.DTYPE, device=config.DEVICE).flatten()[
                self.mask
            ]
        else:
            self.W = self.model.target[self.model.window].flatten("weight")[self.mask]

        # The forward model which computes the output image given input parameters
        self.full_forward = lambda x: model(params=x).flatten("data")[self.mask]

        def chunk_forward(c, state, x):
            # Evaluate the model with the parameters of chunk ``c`` replaced by
            # ``x`` inside a copy of the full parameter vector ``state``.
            return self.model(
                params=backend.fill_at_indices(backend.copy(state), self.chunks[c], x),
            ).flatten("data")[self.mask]

        # Jacobian of the chunk forward model with respect to ``x`` only
        chunk_jacobian = backend.jacfwd(chunk_forward, argnums=2)

        # One forward function and one jacobian function per chunk
        self.forward = []
        self.jacobian = []
        for c in range(len(self.chunks)):
            self.forward.append(partial(chunk_forward, c))
            self.jacobian.append(partial(chunk_jacobian, c))

        # variable to store covariance matrix if it is ever computed
        self._covariance_matrix = None

        # Degrees of freedom
        if ndf is None:
            self.ndf = max(1.0, len(self.Y) - len(self.current_state))
        else:
            self.ndf = ndf

    def make_chunks(self, chunks):
        """Turn the ``chunks`` specification into a sequence of selectors.

        Args:
            chunks: Either a positive integer chunk size, or an already built
                collection of selectors which is returned unchanged.

        Returns:
            For an integer, a list of boolean numpy arrays (one per chunk, each
            as long as ``self.current_state``) selecting consecutive groups of
            at most ``chunks`` parameters. Otherwise ``chunks`` itself.

        Raises:
            ValueError: If ``chunks`` is a non-positive integer.
        """
        if not isinstance(chunks, int):
            # NOTE(review): user supplied selectors are used directly for
            # indexing; the class docstring describes them as 1/0 arrays, so
            # presumably they should be boolean masks -- confirm.
            return chunks
        if chunks <= 0:
            raise ValueError(f"chunks must be a positive integer, got: {chunks}")

        n_params = len(self.current_state)
        selectors = []
        for start in range(0, n_params, chunks):
            selector = np.zeros(n_params, dtype=bool)
            selector[start : start + chunks] = True
            selectors.append(selector)
        return selectors

    def iter_chunks(self):
        """Return the chunk indices in the order they should be visited.

        Returns:
            A list of indices into ``self.chunks``; shuffled for
            ``chunk_order == "random"``, ascending for ``"sequential"``.

        Raises:
            ValueError: If ``self.chunk_order`` is not a recognized option.
        """
        if self.chunk_order == "random":
            chunk_ids = list(range(len(self.chunks)))
            np.random.shuffle(chunk_ids)
        elif self.chunk_order == "sequential":
            chunk_ids = list(range(len(self.chunks)))
        else:
            raise ValueError(
                f"Unrecognized chunk_order: {self.chunk_order}. Should be one of: random, sequential"
            )
        return chunk_ids

    def chi2_ndf(self):
        """Weighted chi^2 of the current state divided by the degrees of freedom."""
        return (
            backend.sum(self.W * (self.Y - self.full_forward(self.current_state)) ** 2) / self.ndf
        )

    def poisson_2nll_ndf(self):
        """Twice the Poisson negative log-likelihood (up to a constant) per
        degree of freedom at the current state."""
        M = self.full_forward(self.current_state)
        # Small offset keeps the log finite where the model is zero
        return 2 * backend.sum(M - self.Y * backend.log(M + 1e-10)) / self.ndf

    @torch.no_grad()
    def fit(self, update_uncertainty=True) -> BaseOptimizer:
        """This performs the fitting operation. It iterates the LM step
        function until convergence is reached. Includes a message
        after fitting to indicate how the fitting exited. Typically if
        the message returns a "success" then the algorithm found a
        minimum. This may be the desired solution, or a pathological
        local minimum, this often depends on the initial conditions.

        Each iteration applies one LM step per chunk. If no chunk at all is
        able to take a step in an iteration, the parameters cannot change any
        more and the loop stops with an explicit message.

        Args:
            update_uncertainty: If True, ``update_uncertainty`` is called once
                the fit has finished.

        Returns:
            This optimizer instance.
        """

        if len(self.current_state) == 0:
            if self.verbose > 0:
                config.logger.warning("No parameters to optimize. Exiting fit")
            self.message = "No parameters to optimize. Exiting fit"
            return self

        if self.likelihood == "gaussian":
            quantity = "Chi^2/DoF"
            self.loss_history = [self.chi2_ndf().item()]
        elif self.likelihood == "poisson":
            quantity = "2NLL/DoF"
            self.loss_history = [self.poisson_2nll_ndf().item()]
        self._covariance_matrix = None
        self.L_history = [self.L]
        self.lambda_history = [backend.to_numpy(backend.copy(self.current_state))]
        if self.verbose > 0:
            config.logger.info(
                f"==Starting LM fit for '{self.model.name}' with {len(self.current_state)} dynamic parameters and {len(self.Y)} pixels=="
            )

        for _ in range(self.max_iter):
            # Report status
            if self.verbose > 0:
                config.logger.info(f"{quantity}: {self.loss_history[-1]:.6g}, L: {self.L:.3g}")

            # Perform fitting
            chunk_L = []
            last_nll = None  # "nll" reported by the last chunk that took a step
            any_failed = False
            for c in self.iter_chunks():
                try:
                    if self.fit_valid:
                        with ValidContext(self.model):
                            valid_state = self.model.to_valid(self.current_state)
                            res = func.lm_step(
                                x=valid_state[self.chunks[c]],
                                data=self.Y,
                                model=partial(self.forward[c], valid_state),
                                weight=self.W,
                                jacobian=partial(self.jacobian[c], valid_state),
                                L=self.L,
                                Lup=self.Lup,
                                Ldn=self.Ldn,
                                likelihood=self.likelihood,
                            )
                            self.current_state = self.model.from_valid(
                                backend.fill_at_indices(
                                    valid_state, self.chunks[c], backend.copy(res["x"])
                                )
                            )
                    else:
                        res = func.lm_step(
                            x=self.current_state[self.chunks[c]],
                            data=self.Y,
                            model=partial(self.forward[c], self.current_state),
                            weight=self.W,
                            jacobian=partial(self.jacobian[c], self.current_state),
                            L=self.L,
                            Lup=self.Lup,
                            Ldn=self.Ldn,
                            likelihood=self.likelihood,
                        )
                        self.current_state = backend.fill_at_indices(
                            self.current_state, self.chunks[c], backend.copy(res["x"])
                        )
                except OptimizeStopFail:
                    any_failed = True
                    if self.verbose > 0:
                        config.logger.warning(
                            f"Could not find step to improve Chi^2 on chunk {c}, moving to next chunk"
                        )
                    continue
                except OptimizeStopSuccess:
                    continue  # success on individual chunk is not enough to stop overall fit
                chunk_L.append(res["L"])
                last_nll = res["nll"]

            if not chunk_L:
                # Not a single chunk took a step, so the state did not change
                # and there is no damping value or loss to record. Stop here
                # instead of failing on an empty reduction below.
                if any_failed:
                    self.message = self.message + "fail. Could not find step to improve Chi^2"
                else:
                    self.message = self.message + "success"
                break

            # Record progress
            self.L = np.clip(np.max(chunk_L), 1e-9, 1e9)
            self.L_history.append(self.L)
            self.loss_history.append(2 * last_nll / self.ndf)
            self.lambda_history.append(backend.to_numpy(backend.copy(self.current_state)))

            if self.check_convergence():
                break
        else:
            self.message = self.message + "fail. Maximum iterations"

        if self.verbose > 0:
            config.logger.info(
                f"Final {quantity}: {np.nanmin(self.loss_history):.6g}, L: {self.L_history[np.nanargmin(self.loss_history)]:.3g}. Converged: {self.message}"
            )

        self.model.set_values(
            backend.as_array(self.res(), dtype=config.DTYPE, device=config.DEVICE)
        )
        if update_uncertainty:
            self.update_uncertainty()

        return self

    def check_convergence(self) -> bool:
        """Check if the optimization has converged based on the last
        iteration's chi^2 and the relative tolerance.

        Only the strictly improving subsequence of ``loss_history`` is used
        for the tolerance tests. On convergence the reason is appended to
        ``self.message``.

        Returns:
            True if the fit should stop, False otherwise.
        """
        if len(self.loss_history) < 3:
            return False

        # Keep only the losses that improved on the best value before them
        good_history = [self.loss_history[0]]
        for loss in self.loss_history[1:]:
            if good_history[-1] > loss:
                good_history.append(loss)

        # Ten or more iterations in total brought no improvement
        if len(self.loss_history) - len(good_history) >= 10:
            self.message = self.message + "success by immobility. Convergence not guaranteed"
            return True
        if len(good_history) < 3:
            return False

        # Last improvement is below tolerance while the damping is small
        if (good_history[-2] - good_history[-1]) / good_history[
            -1
        ] < self.relative_tolerance and self.L < 0.1:
            self.message = self.message + "success"
            return True
        if len(good_history) < 10:
            return False

        # Ten improvements together are still below tolerance
        if (good_history[-10] - good_history[-1]) / good_history[-1] < self.relative_tolerance:
            self.message = self.message + "success by immobility. Convergence not guaranteed"
            return True
        return False

    @property
    @torch.no_grad()
    def covariance_matrix(self):
        """The covariance matrix for the model at the current
        parameters. This can be used to construct a full Gaussian PDF for the
        parameters using: $\\mathcal{N}(\\mu,\\Sigma)$ where $\\mu$ is the
        optimized parameters and $\\Sigma$ is the covariance matrix.

        The matrix is assembled chunk by chunk: the Hessian of each chunk is
        inverted and written into the rows/columns of that chunk, all other
        entries keep their initial value of zero. The result is cached until
        the next call to ``fit``.
        """

        if self._covariance_matrix is not None:
            return self._covariance_matrix

        N = len(self.current_state)
        self._covariance_matrix = backend.zeros((N, N), dtype=config.DTYPE, device=config.DEVICE)

        # The model image does not depend on the chunk, evaluate it only once
        model_image = None
        if self.likelihood == "poisson":
            model_image = self.full_forward(self.current_state)

        for c in self.iter_chunks():
            J = self.jacobian[c](self.current_state, self.current_state[self.chunks[c]])
            if self.likelihood == "gaussian":
                hess = func.hessian(J, self.W)
            elif self.likelihood == "poisson":
                hess = func.hessian_poisson(J, self.Y, model_image)
            try:
                sub_covariance_matrix = backend.linalg.inv(hess)
            except Exception:
                # Only genuine errors trigger the fallback; KeyboardInterrupt
                # and SystemExit are allowed to propagate.
                config.logger.warning(
                    "WARNING: Hessian is singular, likely at least one parameter is non-physical. Will use pseudo-inverse of Hessian to continue but results should be inspected."
                )
                sub_covariance_matrix = backend.linalg.pinv(hess)

            # Row/column indices of this chunk inside the full matrix
            chunk_ids = backend.as_array(
                np.arange(N)[self.chunks[c]], dtype=int, device=config.DEVICE
            )
            ids = backend.meshgrid(chunk_ids, chunk_ids, indexing="ij")
            self._covariance_matrix = backend.fill_at_indices(
                self._covariance_matrix, (ids[0], ids[1]), sub_covariance_matrix
            )
        return self._covariance_matrix

    @torch.no_grad()
    def update_uncertainty(self) -> None:
        """Call this function after optimization to set the uncertainties for
        the parameters. This will use the diagonal of the covariance matrix to
        update the uncertainties. See the covariance_matrix function for the
        full representation of the uncertainties.

        If the covariance matrix contains non finite values, or setting the
        values raises a ``RuntimeError``, a warning is logged and the
        uncertainties are left untouched.
        """
        # set the uncertainty for each parameter
        cov = self.covariance_matrix
        if backend.all(backend.isfinite(cov)):
            try:
                self.model.set_values(
                    backend.sqrt(backend.abs(backend.diag(cov))), attribute="uncertainty"
                )
            except RuntimeError as e:
                config.logger.warning(f"Unable to update uncertainty due to: {e}")
        else:
            config.logger.warning(
                "Unable to update uncertainty due to non finite covariance matrix"
            )