Source code for astrophot.fit.mhmcmc

# Metropolis-Hastings Markov-Chain Monte-Carlo
from typing import Optional, Sequence

import numpy as np

try:
    import emcee
except ImportError:
    emcee = None

from .base import BaseOptimizer
from ..models import Model
from .. import config
from ..backend_obj import backend

__all__ = ("MHMCMC",)


class MHMCMC(BaseOptimizer):
    """Metropolis-Hastings Markov-Chain Monte-Carlo sampler, based on:
    https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm .

    This is simply a thin wrapper for the Emcee package, which is a
    well-known MCMC sampler. Note that the Emcee sampler requires multiple
    walkers to sample the parameter space efficiently. The number of walkers
    is set to twice the number of parameters by default, but can be made
    higher (not lower) if desired. This is done by passing a 2D array of
    shape (nwalkers, ndim) to the `fit` method.

    NOTE(review): emcee's ``EnsembleSampler`` uses the affine-invariant
    ensemble ("stretch") move by default rather than a textbook random-walk
    Metropolis-Hastings proposal.

    :param model: The model whose parameters are sampled.
    :param initial_state: Optional starting parameter vector, forwarded to
        ``BaseOptimizer``.
    :param max_iter: Number of steps per walker taken by `fit` when
        ``nsamples`` is not given. Default is 1000.
    :param likelihood: The likelihood function to use for the MCMC sampling.
        Can be "gaussian" or "poisson". Default is "gaussian".
    :raises ImportError: If the ``emcee`` package is not installed.
    """

    def __init__(
        self,
        model: Model,
        initial_state: Optional[Sequence] = None,
        max_iter: int = 1000,
        likelihood="gaussian",
        **kwargs,
    ):
        super().__init__(model, initial_state, max_iter=max_iter, **kwargs)
        if emcee is None:
            raise ImportError(
                "The emcee package is required for MHMCMC sampling. Please install it with `pip install emcee` or the like."
            )
        self.likelihood = likelihood
        # Samples recorded by `fit`; stays empty until the first call.
        self.chain = []

    def density(self):
        """Build the vectorized log-likelihood function handed to emcee.

        The returned callable takes a batch of parameter vectors (one row per
        walker, as emcee supplies them with ``vectorize=True``) as a NumPy
        array, evaluates the selected model log-likelihood for every row on
        the configured backend, and returns the values as a NumPy array.
        No separate prior term is added here.

        :returns: Callable mapping a NumPy batch of states to a NumPy array of
            log-likelihood values.
        :raises ValueError: If ``self.likelihood`` is not "gaussian" or "poisson".
        """
        if self.likelihood == "gaussian":
            vll = backend.vmap(self.model.gaussian_log_likelihood)
        elif self.likelihood == "poisson":
            vll = backend.vmap(self.model.poisson_log_likelihood)
        else:
            raise ValueError(f"Unknown likelihood type: {self.likelihood}")

        def dens(state: np.ndarray) -> np.ndarray:
            # emcee works in NumPy; move the batch onto the configured backend
            # with the configured dtype/device, then bring the result back.
            state = backend.as_array(state, dtype=config.DTYPE, device=config.DEVICE)
            return backend.to_numpy(vll(state))

        return dens

    def fit(
        self,
        state: Optional[np.ndarray] = None,
        nsamples: Optional[int] = None,
        restart_chain: bool = True,
        skip_initial_state_check: bool = True,
        flat_chain: bool = True,
    ):
        """Run emcee's ensemble sampler and record the chain for later examination.

        :param state: Starting point. A 1D parameter vector is expanded into
            ``2 * ndim`` walkers by multiplying it with N(1, 0.01) jitter; a 2D
            array of shape (nwalkers, ndim) is used as the walker positions
            directly. Defaults to ``self.current_state``.
        :param nsamples: Number of steps per walker. Defaults to ``self.max_iter``.
        :param restart_chain: If True, replace ``self.chain`` with the new
            samples; otherwise append them along the first axis (use the same
            ``flat_chain`` setting across appended runs so the shapes agree).
        :param skip_initial_state_check: Forwarded to emcee's ``run_mcmc``.
        :param flat_chain: If True, store the chain as (nsteps * nwalkers, ndim);
            otherwise as (nsteps, nwalkers, ndim).
        :returns: ``self``; the model is set to the last recorded sample
            (last walker of the last step).
        """
        if nsamples is None:
            nsamples = self.max_iter
        if state is None:
            state = self.current_state
        if not isinstance(state, np.ndarray):
            # emcee operates on NumPy arrays; a backend array (presumably what
            # current_state holds) is converted so the jitter below and the
            # sampler both work on host data regardless of backend/device.
            state = backend.to_numpy(state)

        if len(state.shape) == 1:
            nwalkers = state.shape[0] * 2
            # NOTE(review): the jitter is multiplicative, so a parameter whose
            # value is exactly 0 is identical across all walkers and can never
            # be moved by emcee's ensemble proposals; pass a 2D `state` with
            # spread-out walkers in that case.
            state = state * np.random.normal(loc=1, scale=0.01, size=(nwalkers, state.shape[0]))
        else:
            nwalkers = state.shape[0]
        ndim = state.shape[1]

        sampler = emcee.EnsembleSampler(nwalkers, ndim, self.density(), vectorize=True)
        sampler.run_mcmc(state, nsamples, skip_initial_state_check=skip_initial_state_check)

        new_chain = sampler.get_chain(flat=flat_chain)
        if restart_chain or len(self.chain) == 0:
            # The initial `self.chain` is an empty 1-D list, which cannot be
            # concatenated with a 2-D/3-D chain, so the first run always assigns.
            self.chain = new_chain
        else:
            self.chain = np.concatenate((self.chain, new_chain), axis=0)

        # Flattening to (-1, ndim) makes the last row the last walker of the
        # last step for both flat and unflattened chains, so a single
        # parameter vector is always passed to the model.
        last_sample = np.reshape(new_chain, (-1, ndim))[-1]
        self.model.set_values(
            backend.as_array(last_sample, dtype=config.DTYPE, device=config.DEVICE)
        )
        return self