Source code for astrophot.fit.func.mala

import numpy as np
from tqdm import tqdm


def mala(
    initial_state,  # (num_chains, D)
    log_prob,  # (num_chains, D) -> (num_chains,)
    log_prob_grad,  # (num_chains, D) -> (num_chains, D)
    num_samples,
    epsilon,
    mass_matrix,  # covariance
    progress=True,
    desc="MALA",
):
    """Run a batch of independent preconditioned MALA chains.

    Each of the ``C`` chains is advanced with the Langevin proposal

        x' = x + 0.5 * epsilon**2 * M @ grad_log_p(x) + epsilon * L @ z,

    where ``M = L @ L.T`` is ``mass_matrix`` and ``z ~ N(0, I)``.  Every
    proposal is accepted or rejected with the Metropolis-Hastings ratio,
    including the (asymmetric) forward and reverse proposal densities.

    Parameters
    ----------
    initial_state : array_like, shape (C, D)
        Starting point of every chain.  Non-floating input (e.g. integers)
        is promoted to float64 so accepted proposals are not truncated.
    log_prob : callable
        Maps a (C, D) array to the (C,) log-density of each row.
    log_prob_grad : callable
        Maps a (C, D) array to the (C, D) gradient of ``log_prob`` per row.
    num_samples : int
        Number of MALA steps; one state per chain is recorded per step.
    epsilon : float
        Step size.
    mass_matrix : array_like, shape (D, D)
        Proposal covariance / preconditioner.  Must be symmetric positive
        definite (it is inverted and Cholesky-factorised).
    progress : bool, optional
        If True, wrap the loop in a tqdm bar showing the running
        acceptance rate (fraction of accepted moves over all chains/steps).
    desc : str, optional
        Label for the progress bar.

    Returns
    -------
    samples : ndarray, shape (num_samples, C, D)
        Chain states after every step.
    logp : ndarray, shape (num_samples, C)
        ``log_prob`` evaluated at each recorded state.
    """
    x = np.array(initial_state, copy=True)
    if not np.issubdtype(x.dtype, np.floating):
        # An integer state array would silently truncate accepted proposals.
        x = x.astype(np.float64)
    C, D = x.shape

    # np.asarray rather than np.array(copy=False): under NumPy >= 2 the latter
    # raises whenever a copy is required (e.g. mass_matrix given as a list).
    mass = np.asarray(mass_matrix)  # (D, D)
    inv_mass = np.linalg.inv(mass)  # (D, D)
    L = np.linalg.cholesky(mass)  # (D, D), mass = L @ L.T

    samples = np.zeros((num_samples, C, D), dtype=x.dtype)  # (N, C, D)
    logp = np.zeros((num_samples, C), dtype=x.dtype)  # (N, C)

    # Cached current state.  These arrays are updated in place below, so keep
    # private writable copies: callbacks may return read-only arrays or reuse
    # an internal buffer between calls.
    logp_cur = np.array(log_prob(x), copy=True)  # (C,)
    grad_cur = np.array(log_prob_grad(x), copy=True)  # (C, D)

    # Dedicated generator seeded from NumPy's global state, so that
    # np.random.seed(...) makes a run reproducible.
    rng = np.random.default_rng(np.random.randint(10**9))

    half_eps2 = 0.5 * (epsilon**2)
    n_accepted = 0  # running count; replaces an ever-growing history array

    it = range(num_samples)
    if progress:
        it = tqdm(it, desc=desc, position=0, leave=True)
    for t in it:
        # Proposal: drift along the preconditioned gradient plus Gaussian
        # noise with covariance epsilon**2 * mass (rows: z @ L.T == (L z)^T).
        mu_x = half_eps2 * (grad_cur @ mass)  # (C, D)
        noise = rng.standard_normal((C, D)) @ L.T  # (C, D)
        x_prop = x + mu_x + epsilon * noise  # (C, D)

        # Evaluate proposal (read-only here, so no copy needed)
        logp_prop = np.asarray(log_prob(x_prop))  # (C,)
        grad_prop = np.asarray(log_prob_grad(x_prop))  # (C, D)
        mu_xprop = half_eps2 * (grad_prop @ mass)  # (C, D)

        # log q(a | b) = -0.5 (a - b - mu(b))^T M^{-1} (a - b - mu(b)) / eps^2 + const
        d1 = x - x_prop - mu_xprop  # reverse move, q(x | x')
        d2 = x_prop - x - mu_x  # forward move, q(x' | x)
        logq1 = -0.5 * np.einsum("bi,ij,bj->b", d1, inv_mass, d1) / epsilon**2  # (C,)
        logq2 = -0.5 * np.einsum("bi,ij,bj->b", d2, inv_mass, d2) / epsilon**2  # (C,)

        log_alpha = (logp_prop - logp_cur) + (logq1 - logq2)  # (C,)
        # A NaN log_alpha compares False, so such proposals are rejected.
        accept = np.log(rng.random(C)) < log_alpha  # (C,)
        n_accepted += int(np.count_nonzero(accept))

        # Update all three cached pieces in place where accepted
        x[accept] = x_prop[accept]  # (C, D)
        logp_cur[accept] = logp_prop[accept]  # (C,)
        grad_cur[accept] = grad_prop[accept]  # (C, D)

        # Slice assignment copies, so later in-place updates don't leak back.
        samples[t] = x
        logp[t] = logp_cur
        if progress:
            it.set_postfix(acc_rate=f"{n_accepted / ((t + 1) * C):0.2f}")

    return samples, logp