# Source code for astrophot.fit.mala
# Metropolis-Adjusted Langevin Algorithm sampler
from typing import Optional, Sequence
import numpy as np
from .base import BaseOptimizer
from ..models import Model
from .. import config
from ..backend_obj import backend
from . import func
__all__ = ("MALA",)
# [docs]
class MALA(BaseOptimizer):
    """Metropolis-Adjusted Langevin Algorithm (MALA) sampler, based on:
    https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm . This
    is a gradient-based MCMC sampler that uses the gradient of the
    log-likelihood to propose new samples. These gradient based proposals can
    lead to more efficient sampling of the parameter space. This is especially
    true when the mass_matrix is set well. A good guess for the mass matrix is
    the covariance matrix of the likelihood at the maximum likelihood point,
    which can be found fairly easily with the LM optimizer (see the fitting
    methods tutorial).

    :param model: The model whose log-likelihood is sampled. Forwarded to ``BaseOptimizer``.
    :param initial_state: Optional starting state, forwarded to ``BaseOptimizer``. If the resulting ``current_state`` is 2D, its first dimension is used as the number of chains and ``chains`` is ignored.
    :param chains: The number of MCMC chains to run in parallel. Default is 4.
    :param epsilon: The step size for the MALA sampler. Default is 1e-2.
    :param mass_matrix: The mass matrix for the MALA sampler. If None, the identity matrix is used.
    :param max_iter: Iteration count forwarded to ``BaseOptimizer`` and to ``func.mala`` (presumably the number of sampling steps -- confirm against ``func.mala``). Default is 1000.
    :param progress_bar: Whether to show a progress bar during sampling. Default is True.
    :param likelihood: The likelihood function to use for the MCMC sampling. Can be "gaussian" or "poisson". Default is "gaussian".
    """

    def __init__(
        self,
        model: Model,
        initial_state: Optional[Sequence] = None,
        chains: int = 4,
        epsilon: float = 1e-2,
        mass_matrix: Optional[np.ndarray] = None,
        max_iter: int = 1000,
        progress_bar: bool = True,
        likelihood: str = "gaussian",
        **kwargs,
    ):
        super().__init__(model, initial_state, max_iter=max_iter, **kwargs)

        self.chain = []
        # A 2D state already carries one row per chain, so its leading
        # dimension takes precedence over the ``chains`` argument.
        # NOTE(review): ``current_state`` is presumably set by
        # ``BaseOptimizer.__init__`` -- it is not assigned in this class.
        if len(self.current_state.shape) == 2:
            self.chains = self.current_state.shape[0]
        else:
            self.chains = chains
        self.likelihood = likelihood
        self.epsilon = epsilon
        self.mass_matrix = mass_matrix
        self.progress_bar = progress_bar

    def _log_likelihood_func(self):
        """Return the model's log-likelihood method selected by ``self.likelihood``.

        Shared by :meth:`density_func` and :meth:`density_grad_func` so both
        always dispatch on the likelihood name in the same way.

        :raises ValueError: If ``self.likelihood`` is neither "gaussian" nor "poisson".
        """
        if self.likelihood == "gaussian":
            return self.model.gaussian_log_likelihood
        if self.likelihood == "poisson":
            return self.model.poisson_log_likelihood
        raise ValueError(f"Unknown likelihood type: {self.likelihood}")

    def density_func(self):
        """
        Build a function that evaluates the selected log-likelihood of the
        model for a given state array.

        The returned callable converts its input to a backend array (using
        ``config.DTYPE`` and ``config.DEVICE``), applies the log-likelihood
        vectorised with ``backend.vmap``, and returns the result as a numpy
        array.

        :raises ValueError: If ``self.likelihood`` is not a known likelihood type.
        """
        vll = backend.vmap(self._log_likelihood_func())

        def dens(state: np.ndarray) -> np.ndarray:
            state = backend.as_array(state, dtype=config.DTYPE, device=config.DEVICE)
            return backend.to_numpy(vll(state))

        return dens

    def density_grad_func(self):
        """
        Build a function that evaluates the gradient of the selected
        log-likelihood of the model for a given state array.

        The returned callable converts its input to a backend array (using
        ``config.DTYPE`` and ``config.DEVICE``), applies ``backend.grad`` of
        the log-likelihood vectorised with ``backend.vmap``, and returns the
        result as a numpy array.

        :raises ValueError: If ``self.likelihood`` is not a known likelihood type.
        """
        vll_grad = backend.vmap(backend.grad(self._log_likelihood_func()))

        def grad(state: np.ndarray) -> np.ndarray:
            state = backend.as_array(state, dtype=config.DTYPE, device=config.DEVICE)
            return backend.to_numpy(vll_grad(state))

        return grad

    def fit(self):
        """
        Run the MALA sampler and load the best sample into the model.

        The chains start from ``current_state``; a 1D state is repeated once
        per chain. If no mass matrix was supplied, an identity matrix is
        created and stored on ``self.mass_matrix``. Sampling is delegated to
        ``func.mala`` and its two return values are stored as ``self.chain``
        and ``self.logp``. Finally the entry of ``self.chain`` at the position
        of the largest ``self.logp`` value is written into the model.

        :raises ValueError: If ``self.likelihood`` is not a known likelihood type.
        :return: ``self``, so calls can be chained.
        """
        Px = self.density_func()
        dPdx = self.density_grad_func()

        initial_state = backend.to_numpy(self.current_state)
        if len(initial_state.shape) == 1:
            # Single state vector: every chain starts from the same point.
            initial_state = np.repeat(initial_state[None, :], self.chains, axis=0)

        if self.mass_matrix is None:
            D = initial_state.shape[1]
            self.mass_matrix = np.eye(D, dtype=initial_state.dtype)

        self.chain, self.logp = func.mala(
            initial_state,
            Px,
            dPdx,
            self.max_iter,
            self.epsilon,
            self.mass_matrix,
            progress=self.progress_bar,
            desc="MALA",
        )

        # Fill model with max logp sample. argmax works on the flattened
        # array, so convert it back to a multi-dimensional index into logp.
        max_logp_index = np.unravel_index(np.argmax(self.logp), self.logp.shape)
        self.model.set_values(
            backend.as_array(self.chain[max_logp_index], dtype=config.DTYPE, device=config.DEVICE)
        )
        return self