# Source code for astrophot.fit.batch_lm

import numpy as np
from ..models import Model
from ..image import TargetImageBatch, WindowBatch
from .base import BaseOptimizer
from ..backend_obj import backend, ArrayLike
from .. import config
from ..errors import OptimizeStopSuccess
from ..param import ValidContext
from . import func


class BatchLM(BaseOptimizer):
    """Levenberg-Marquardt optimizer fitting one model to a batch of images.

    The model is evaluated for every image of ``batch_target`` (restricted to
    ``batch_window``) through ``backend.vmap``. Each batch element therefore has
    its own parameter vector (one row of ``current_state``) and its own damping
    value (one entry of ``L``). The damped update itself is delegated to
    ``func.batch_lm_step`` once per iteration.

    Parameters
    ----------
    model : Model
        Model to optimize; ``model.get_values()`` seeds the initial state.
    batch_target : TargetImageBatch
        Batch of target images to fit.
    batch_window : WindowBatch
        Windows selecting the region of each target used in the fit.
    max_iter : int
        Maximum number of LM iterations performed by ``fit``.
    relative_tolerance : float
        Threshold on the relative loss change used by ``check_convergence``.
    Lup, Ldn : float
        Damping adjustment factors forwarded to ``func.batch_lm_step``
        (presumably multiplicative up/down factors; confirm in ``func``).
    L0 : float
        Initial damping value, applied to every batch element.
    max_step_iter : int
        Forwarded to ``func.batch_lm_step`` (presumably the number of step
        attempts per iteration; confirm in ``func``).
    likelihood : str
        ``"gaussian"`` or ``"poisson"``.
    **kwargs
        Forwarded to ``BaseOptimizer.__init__``.

    Raises
    ------
    ValueError
        If ``likelihood`` is not ``"gaussian"`` or ``"poisson"``.
    OptimizeStopSuccess
        If every pixel of the windowed batch is masked.
    """

    def __init__(
        self,
        model: Model,
        batch_target: TargetImageBatch,
        batch_window: WindowBatch,
        max_iter: int = 100,
        relative_tolerance: float = 1e-5,
        Lup: float = 11.0,
        Ldn: float = 9.0,
        L0: float = 1.0,
        max_step_iter: int = 3,
        likelihood: str = "gaussian",
        **kwargs,
    ):
        super().__init__(
            model=model,
            initial_state=model.get_values(),
            max_iter=max_iter,
            relative_tolerance=relative_tolerance,
            **kwargs,
        )
        self.max_step_iter = max_step_iter
        # Cache used by the ``covariance_matrix`` property. Initialised here so
        # the property (and ``update_uncertainty``) also work before ``fit``.
        self._covariance_matrix = None

        # Likelihood
        self.likelihood = likelihood
        if self.likelihood not in ["gaussian", "poisson"]:
            raise ValueError(
                f"Unsupported likelihood: {self.likelihood}, should be one of: 'gaussian' or 'poisson'"
            )

        # Index the batch once; mask, data, weight and crpix all come from it.
        windowed_target = batch_target[batch_window]

        # mask: flattened from axis 1 onward; inverted so True marks pixels
        # that take part in the fit.
        mask = backend.flatten(windowed_target.mask, 1, -1)
        self.mask = ~mask
        if backend.sum(self.mask).item() == 0:
            raise OptimizeStopSuccess("No data to fit. All pixels are masked")

        # data
        self.data = backend.flatten(windowed_target.data, 1, -1)

        # Weight
        self.weight = backend.flatten(windowed_target.weight, 1, -1)

        # WCS
        crtan = batch_target.crtan
        shift = backend.as_array(
            batch_window.origin_shifter(self.model.window),
            dtype=config.DTYPE,
            device=config.DEVICE,
        )
        crpix = windowed_target.crpix + shift
        CD = batch_target.CD
        psf = batch_target.psf_stack
        # Without a PSF stack there is nothing to map over for that argument.
        psf_batch = None if psf is None else 0

        def _flat_model(cd, crt, crp, psf_i, params):
            # Single-image model evaluation, flattened to a 1D pixel vector.
            return backend.flatten(self.model(cd, crt, crp, psf_i, params=params).data)

        in_dims = (0, 0, 0, psf_batch, 0)

        # Forward
        vmodel = backend.vmap(_flat_model, in_dims=in_dims)
        self.forward = lambda x: vmodel(CD, crtan, crpix, psf, x)

        # Jacobian with respect to ``params`` (argument index 4)
        vjac = backend.vmap(backend.jacfwd(_flat_model, argnums=4), in_dims=in_dims)
        self.jacobian = lambda x: vjac(CD, crtan, crpix, psf, x)

        # ndf: unmasked pixels minus number of parameters, per batch element,
        # clamped to at least 1.
        self.ndf = backend.clamp(
            backend.sum(self.mask, dim=1) - self.current_state.shape[1],
            backend.as_array(1),
            None,
        )

        # LM parameters
        self.Lup = Lup
        self.Ldn = Ldn
        self.L = L0 * backend.ones(
            self.current_state.shape[0], dtype=config.DTYPE, device=config.DEVICE
        )

    def chi2_ndf(self):
        """Return the masked, weighted chi^2 divided by ``ndf``, per batch element."""
        residual = self.data - self.forward(self.current_state)
        return backend.sum(self.weight * self.mask * residual**2, dim=1) / self.ndf

    def poisson_2nll_ndf(self):
        """Return 2x the masked Poisson negative log-likelihood divided by ``ndf``.

        The data-only term is omitted (the sum is ``M - d*log(M)``). ``1e-10``
        guards ``log`` against zero model values.
        """
        M = self.forward(self.current_state)
        return (
            2 * backend.sum((M - self.data * backend.log(M + 1e-10)) * self.mask, dim=1)
            / self.ndf
        )

    def _lm_step(self, x):
        """Run one batched ``func.batch_lm_step`` starting from state ``x``."""
        return func.batch_lm_step(
            x=x,
            data=self.data,
            model=self.forward,
            weight=self.weight,
            mask=self.mask,
            jacobian=self.jacobian,
            L=self.L,
            Lup=self.Lup,
            Ldn=self.Ldn,
            likelihood=self.likelihood,
            max_step_iter=self.max_step_iter,
        )

    def fit(self, update_uncertainty: bool = True):
        """Iterate LM steps until convergence or ``max_iter`` iterations.

        Records ``loss_history``, ``L_history`` and ``lambda_history`` (the
        state after each iteration). It then writes the final state to the
        model and, optionally, the uncertainties.

        Parameters
        ----------
        update_uncertainty : bool
            If True, call ``update_uncertainty`` after fitting.

        Returns
        -------
        BatchLM
            ``self``, to allow chaining.
        """
        if self.current_state.shape[1] == 0:
            if self.verbose > 0:
                config.logger.warning("No parameters to optimize. Exiting fit")
            self.message = "No parameters to optimize. Exiting fit"
            return self

        if self.likelihood == "gaussian":
            quantity = "Chi^2/DoF"
            self.loss_history = [backend.to_numpy(self.chi2_ndf())]
        elif self.likelihood == "poisson":
            quantity = "2NLL/DoF"
            self.loss_history = [backend.to_numpy(self.poisson_2nll_ndf())]
        # The state is about to change, so any cached covariance is stale.
        self._covariance_matrix = None
        self.L_history = [backend.to_numpy(self.L)]
        self.lambda_history = [backend.to_numpy(backend.copy(self.current_state))]
        if self.verbose > 0:
            config.logger.info(
                f"==Starting LM fit for '{self.model.name}' with batch of {self.current_state.shape[0]} images with {self.current_state.shape[1]} dynamic parameters and {self.data.shape[1]} pixels=="
            )

        for _ in range(self.max_iter):
            if self.verbose > 0:
                config.logger.info(f"{quantity}: {self.loss_history[-1]}, L: {self.L_history[-1]}")
            if self.fit_valid:
                # Step in the model's "valid" parameterisation, then map back.
                with ValidContext(self.model):
                    res = self._lm_step(self.model.to_valid(self.current_state))
                    self.current_state = self.model.from_valid(backend.copy(res["x"]))
            else:
                res = self._lm_step(self.current_state)
                self.current_state = backend.copy(res["x"])

            # Keep damping within a numerically sane range.
            self.L = backend.clamp(res["L"], backend.as_array(1e-9), backend.as_array(1e9))
            self.L_history.append(backend.to_numpy(self.L))
            self.loss_history.append(2 * res["nll"] / backend.to_numpy(self.ndf))
            self.lambda_history.append(backend.to_numpy(backend.copy(self.current_state)))

            if self.check_convergence():
                break
        else:
            self.message = self.message + "fail. Maximum iterations"

        if self.verbose > 0:
            config.logger.info(
                f"Final {quantity}: {self.loss_history[-1]}, L: {self.L_history[-1]}. Converged: {self.message}"
            )

        self.model.set_values(self.current_state)
        if update_uncertainty:
            self.update_uncertainty()
        return self

    def check_convergence(self) -> bool:
        """Check if the optimization has converged based on the last iteration's
        chi^2 and the relative tolerance.

        Two criteria apply to every batch element at once:

        * the loss change over the last iteration, relative to the latest loss,
          is below ``relative_tolerance`` and all damping values are below 0.1;
        * or, with at least 10 recorded losses, the change over the last 10
          entries is below ``relative_tolerance`` ("immobility").

        On success a message suffix is appended to ``self.message``.
        """
        if len(self.loss_history) < 3:
            return False
        if np.all(
            (self.loss_history[-2] - self.loss_history[-1]) / self.loss_history[-1]
            < self.relative_tolerance
        ) and np.all(backend.to_numpy(self.L) < 0.1):
            self.message = self.message + "success"
            return True
        if len(self.loss_history) < 10:
            return False
        if np.all(
            (self.loss_history[-10] - self.loss_history[-1]) / self.loss_history[-1]
            < self.relative_tolerance
        ):
            self.message = self.message + "success by immobility. Convergence not guaranteed"
            return True
        return False

    @property
    def covariance_matrix(self) -> ArrayLike:
        """The covariance matrix for the model at the current parameters.

        One matrix is computed per batch element by inverting that element's
        Hessian. It can be used to construct a full Gaussian PDF for the
        parameters: $\\mathcal{N}(\\mu,\\Sigma)$, where $\\mu$ is the optimized
        parameters and $\\Sigma$ is the covariance matrix. The result is cached
        until the next call to ``fit``.
        """
        if self._covariance_matrix is not None:
            return self._covariance_matrix
        # Zero the Jacobian rows of masked pixels.
        J = self.jacobian(self.current_state) * self.mask.reshape(self.mask.shape + (1,))
        if self.likelihood == "gaussian":
            hess = backend.vmap(func.hessian)(J, self.weight * self.mask)
        elif self.likelihood == "poisson":
            hess = backend.vmap(func.hessian_poisson)(
                J, self.data * self.mask, self.forward(self.current_state) * self.mask
            )
        try:
            self._covariance_matrix = backend.vmap(backend.linalg.inv)(hess)
        except Exception:
            # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit
            # still propagate; any inversion failure falls back to pinv.
            config.logger.warning(
                "WARNING: Hessian is singular, likely at least one parameter is non-physical. Will use pseudo-inverse of Hessian to continue but results should be inspected."
            )
            self._covariance_matrix = backend.vmap(backend.linalg.pinv)(hess)
        return self._covariance_matrix

    def update_uncertainty(self) -> None:
        """Call this function after optimization to set the uncertainties for
        the parameters.

        The uncertainties are the square roots of the absolute diagonal of each
        batch element's covariance matrix. See the ``covariance_matrix``
        property for the full representation. Nothing is written (only a
        warning is logged) if the covariance contains non-finite values or
        ``set_values`` raises ``RuntimeError``.
        """
        # set the uncertainty for each parameter
        cov = self.covariance_matrix
        if backend.all(backend.isfinite(cov)):
            try:
                self.model.set_values(
                    backend.sqrt(backend.abs(backend.vmap(backend.diag)(cov))),
                    attribute="uncertainty",
                )
            except RuntimeError as e:
                config.logger.warning(f"Unable to update uncertainty due to: {e}")
        else:
            config.logger.warning(
                "Unable to update uncertainty due to non finite covariance matrix"
            )