Source code for astrophot.fit.hmc

# Hamiltonian Monte-Carlo
from typing import Optional, Sequence

import torch

try:
    import pyro
    import pyro.distributions as dist
    from pyro.distributions import Distribution
    from pyro.infer import MCMC as pyro_MCMC
    from pyro.infer import HMC as pyro_HMC
    from pyro.infer.mcmc.adaptation import BlockMassMatrix
    from pyro.ops.welford import WelfordCovariance
except ImportError:
    pyro = None
    Distribution = None

from .base import BaseOptimizer
from ..models import Model
from .. import config

__all__ = ("HMC",)


###########################################
# !Overwrite pyro configuration behavior!
# currently this is the only way to provide
# mass matrix manually
###########################################
def new_configure(self, mass_matrix_shape, adapt_mass_matrix=True, options={}):
    """
    Sets up an initial mass matrix.

    :param mass_matrix_shape: a dict that maps tuples of site names to the shape of
        the corresponding mass matrix. Each tuple of site names corresponds to a block.
    :param adapt_mass_matrix: a flag to decide whether an adaptation scheme will be used.
    :param options: Array options to construct the initial mass matrix.
    """
    inverse_mass_matrix = {}
    for site_names, shape in mass_matrix_shape.items():
        self._mass_matrix_size[site_names] = shape[0]
        diagonal = len(shape) == 1
        inverse_mass_matrix[site_names] = (
            torch.full(shape, self._init_scale, **options)
            if diagonal
            else torch.eye(*shape, **options) * self._init_scale
        )
        if adapt_mass_matrix:
            adapt_scheme = WelfordCovariance(diagonal=diagonal)
            self._adapt_scheme[site_names] = adapt_scheme

    if len(self.inverse_mass_matrix.keys()) == 0:
        self.inverse_mass_matrix = inverse_mass_matrix


if pyro is not None:
    BlockMassMatrix.configure = new_configure
############################################


[docs] class HMC(BaseOptimizer): """Hamiltonian Monte-Carlo sampler wrapper for the Pyro package. This MCMC algorithm uses gradients of the $\\chi^2$ to more efficiently explore the probability distribution. More information on HMC can be found at: https://en.wikipedia.org/wiki/Hamiltonian_Monte_Carlo, https://arxiv.org/abs/1701.02434, and http://www.mcmchandbook.net/HandbookChapter5.pdf :param max_iter: The number of sampling steps to perform. Defaults to 1000. :type max_iter: int, optional :param epsilon: The length of the integration step to perform for each leapfrog iteration. The momentum update will be of order epsilon * score. Defaults to 1e-5. :type epsilon: float, optional :param leapfrog_steps: Number of steps to perform with leapfrog integrator per sample of the HMC. Defaults to 10. :type leapfrog_steps: int, optional :param inv_mass: Inverse Mass matrix (covariance matrix) which can tune the behavior in each dimension to ensure better mixing when sampling. Defaults to the identity. :type inv_mass: float or array, optional :param progress_bar: Whether to display a progress bar during sampling. Defaults to True. :type progress_bar: bool, optional :param prior: Prior distribution for the parameters. Defaults to None. :type prior: distribution, optional :param warmup: Number of warmup steps before actual sampling begins. Defaults to 100. :type warmup: int, optional :param hmc_kwargs: Additional keyword arguments for the HMC sampler. Defaults to {}. :type hmc_kwargs: dict, optional :param mcmc_kwargs: Additional keyword arguments for the MCMC process. Defaults to {}. :type mcmc_kwargs: dict, optional """ def __init__( self, model: Model, initial_state: Optional[Sequence] = None, max_iter: int = 1000, inv_mass: Optional[torch.Tensor] = None, epsilon: float = 1e-4, leapfrog_steps: int = 10, progress_bar: bool = True, prior: Optional["Distribution"] = None, warmup: int = 100, hmc_kwargs: dict = {}, mcmc_kwargs: dict = {}, likelihood: str = "gaussian", **kwargs, ): if pyro is None: raise ImportError("Pyro must be installed to use HMC.") super().__init__(model, initial_state, max_iter=max_iter, **kwargs) self.inv_mass = inv_mass self.epsilon = epsilon self.leapfrog_steps = leapfrog_steps self.progress_bar = progress_bar self.prior = prior self.warmup = warmup self.hmc_kwargs = hmc_kwargs self.mcmc_kwargs = mcmc_kwargs self.likelihood = likelihood self.acceptance = None
[docs] def fit( self, state: Optional[torch.Tensor] = None, ): """Performs MCMC sampling using Hamiltonian Monte-Carlo step. Records the chain for later examination. :param state: Model parameters as a 1D Array. :type state: Array, optional """ def step(model, prior): x = pyro.sample("x", prior) # Log-likelihood function if self.likelihood == "gaussian": log_likelihood_value = model.gaussian_log_likelihood(params=x) elif self.likelihood == "poisson": log_likelihood_value = model.poisson_log_likelihood(params=x) else: raise ValueError(f"Unsupported likelihood type: {self.likelihood}") # Observe the log-likelihood pyro.factor("obs", log_likelihood_value) if self.prior is None: self.prior = dist.Normal( self.current_state, torch.ones_like(self.current_state) * 1e2 + torch.abs(self.current_state) * 1e2, ) # Set up the HMC sampler hmc_kwargs = { "jit_compile": False, "ignore_jit_warnings": True, "full_mass": True, "step_size": self.epsilon, "num_steps": self.leapfrog_steps, "adapt_step_size": False, "adapt_mass_matrix": self.inv_mass is None, } hmc_kwargs.update(self.hmc_kwargs) hmc_kernel = pyro_HMC(step, **hmc_kwargs) if self.inv_mass is not None: hmc_kernel.mass_matrix_adapter.inverse_mass_matrix = {("x",): self.inv_mass} # Provide an initial guess for the parameters init_params = {"x": self.model.get_values()} # Run MCMC with the HMC sampler and the initial guess mcmc_kwargs = { "num_samples": self.max_iter, "warmup_steps": self.warmup, "initial_params": init_params, "disable_progbar": not self.progress_bar, } mcmc_kwargs.update(self.mcmc_kwargs) mcmc = pyro_MCMC(hmc_kernel, **mcmc_kwargs) mcmc.run(self.model, self.prior) self.iteration += self.max_iter # Extract posterior samples chain = mcmc.get_samples()["x"] self.chain = chain self.model.set_values( torch.as_tensor(self.chain[-1], dtype=config.DTYPE, device=config.DEVICE) ) return self