# Source code for astrophot.plots.diagnostic: diagnostic plots (covariance-matrix corner plot).

import numpy as np

from matplotlib.patches import Ellipse
from matplotlib import pyplot as plt
from scipy.stats import norm
from .visuals import main_pallet

# Public API of this module: only the covariance corner-plot function is exported
# by ``from ... import *``.
__all__ = ("covariance_matrix",)


def _plot_marginal(ax, mu, sigma, color, reference=None):
    """Draw the Gaussian pdf N(mu, sigma**2) over mu +/- 3 sigma on ``ax``.

    If ``reference`` is given, a vertical line is drawn at that value.
    """
    lo, hi = mu - 3 * sigma, mu + 3 * sigma
    x = np.linspace(lo, hi, 100)
    ax.plot(x, norm.pdf(x, mu, sigma), color=color, lw=1.5)
    ax.set_xlim(lo, hi)
    if reference is not None:
        ax.axvline(reference, color=main_pallet["pop"], linestyle="-", lw=1)


def _plot_joint(ax, cov2, center, scales, color, reference=None, nsigmas=(1, 2), margin=3):
    """Draw the k-sigma ellipses (k in ``nsigmas``) of a 2x2 covariance on ``ax``.

    ``center`` is the ``(x, y)`` ellipse centre and ``scales`` the ``(x, y)``
    standard deviations used to set the axis limits to
    ``center +/- margin * scale``. If ``reference`` is an ``(x, y)`` pair, a
    vertical and a horizontal line are drawn through it.
    """
    # The covariance block is symmetric. ``eigh`` guarantees real eigenvalues
    # and orthonormal eigenvectors (as columns). The general ``eig`` can return
    # complex output for nearly-degenerate, slightly asymmetric input.
    evals, evecs = np.linalg.eigh(cov2)
    # Clip round-off negatives so the square root does not yield NaN widths.
    semi_axes = np.sqrt(np.clip(evals, 0.0, None))
    # Orientation of the first eigenvector; the Ellipse ``width`` follows it.
    angle = np.rad2deg(np.arctan2(evecs[1, 0], evecs[0, 0]))
    for k in nsigmas:
        ax.add_artist(
            Ellipse(
                xy=center,
                width=2 * k * semi_axes[0],
                height=2 * k * semi_axes[1],
                angle=angle,
                edgecolor=color,
                facecolor="none",
                lw=1.5,
            )
        )
    ax.set_xlim(center[0] - margin * scales[0], center[0] + margin * scales[0])
    ax.set_ylim(center[1] - margin * scales[1], center[1] + margin * scales[1])
    if reference is not None:
        ax.axvline(reference[0], color=main_pallet["pop"], linestyle="-", lw=1)
        ax.axhline(reference[1], color=main_pallet["pop"], linestyle="-", lw=1)


def _format_ticks(ax, i, j, num_params, labels, showticks):
    """Apply the corner-plot tick/label conventions to panel (row ``i``, column ``j``)."""
    # Only the bottom row keeps x tick labels and receives the x-axis label.
    if i < num_params - 1:
        ax.set_xticklabels([])
    else:
        if labels is not None:
            ax.set_xlabel(labels[j])
        if not showticks:
            # NOTE(review): this removes the *y* ticks on the bottom row, and
            # the branch below removes the *x* ticks on the left column. The
            # axes may be swapped relative to the intent; confirm before changing.
            ax.yaxis.set_major_locator(plt.NullLocator())
    # Only the left column keeps y tick labels and receives the y-axis label.
    if j > 0:
        ax.set_yticklabels([])
    else:
        if labels is not None:
            ax.set_ylabel(labels[i])
        if not showticks:
            ax.xaxis.set_major_locator(plt.NullLocator())


def covariance_matrix(
    covariance_matrix,
    mean,
    labels=None,
    figsize=(10, 10),
    reference_values=None,
    ellipse_colors=main_pallet["primary1"],
    showticks=True,
    **kwargs,
):
    """
    Create a covariance matrix plot.

    Creates a corner plot with ellipses representing the covariance between
    parameters:

    - Diagonal panels show the 1D Gaussian pdf of each parameter over
      +/- 3 sigma.
    - Lower-triangle panels show the 1- and 2-sigma ellipses of each
      parameter pair, with limits of +/- 3 sigma.
    - Upper-triangle panels are hidden.

    :param covariance_matrix: Covariance matrix of shape (n_params, n_params).
        Any array-like is accepted; it is converted with ``np.asarray``.
    :type covariance_matrix: np.ndarray
    :param mean: Mean values of the parameters, shape (n_params,).
    :type mean: np.ndarray
    :param labels: Labels for the parameters. They are used as x labels on
        the bottom row and y labels on the left column.
    :type labels: list, optional
    :param figsize: Size of the figure. Default is (10, 10).
    :type figsize: tuple, optional
    :param reference_values: Reference values for the parameters, used to
        draw vertical and horizontal lines. Typically these are the true
        values of the parameters.
    :type reference_values: np.ndarray, optional
    :param ellipse_colors: Matplotlib color used for the marginal curves and
        the ellipses. Default is `main_pallet["primary1"]`.
    :type ellipse_colors: str or list, optional
    :param showticks: Whether to show ticks on the axes. Default is True.
        When False, the tick locators are removed for the y axis of
        bottom-row panels and the x axis of left-column panels.
    :type showticks: bool, optional
    :param kwargs: Currently unused; accepted for backward compatibility.

    :returns: ``(fig, ax)``. ``fig`` is the created figure and ``ax`` is the
        bottom-right panel (the last one drawn). Every panel is available
        through ``fig.axes`` for further customization.
    :raises ValueError: If ``covariance_matrix`` is not a square 2-D array.
    """
    cov = np.asarray(covariance_matrix)
    mean = np.asarray(mean)
    if cov.ndim != 2 or cov.shape[0] != cov.shape[1]:
        raise ValueError(
            f"covariance_matrix must be a square 2-D array, got shape {cov.shape}"
        )
    num_params = cov.shape[0]
    sigmas = np.sqrt(np.diag(cov))

    # squeeze=False keeps ``axes`` 2-D even for a single parameter. Otherwise
    # plt.subplots(1, 1) returns a bare Axes and ``axes[i, j]`` fails.
    fig, axes = plt.subplots(num_params, num_params, figsize=figsize, squeeze=False)
    fig.subplots_adjust(wspace=0.0, hspace=0.0)

    for i in range(num_params):
        for j in range(num_params):
            ax = axes[i, j]
            if i == j:
                ref = None if reference_values is None else reference_values[i]
                _plot_marginal(ax, mean[i], sigmas[i], ellipse_colors, ref)
            elif j < i:
                ref = (
                    None
                    if reference_values is None
                    else (reference_values[j], reference_values[i])
                )
                _plot_joint(
                    ax,
                    cov[np.ix_([j, i], [j, i])],
                    (mean[j], mean[i]),
                    (sigmas[j], sigmas[i]),
                    ellipse_colors,
                    ref,
                )
            else:
                ax.axis("off")
            _format_ticks(ax, i, j, num_params, labels, showticks)

    # Kept for backward compatibility: the second value is the bottom-right
    # panel, not the whole grid.
    return fig, axes[-1, -1]
if __name__ == "__main__":
    # Demo: two parameters with equal variance (4) and negative covariance (-2).
    # NOTE: this module uses a relative import (``from .visuals ...``), so run
    # it as a module, e.g. ``python -m astrophot.plots.diagnostic``.
    demo_covariance = np.array([[4, -2], [-2, 4]])
    demo_mean = np.array([0, 0])
    fig, ax = covariance_matrix(demo_covariance, demo_mean)
    plt.show()