"""Source code for ``astrophot.plots.diagnostic``: diagnostic plotting helpers."""

import numpy as np

from matplotlib.patches import Ellipse
from matplotlib import pyplot as plt
from scipy.stats import norm

__all__ = ("covariance_matrix",)


def covariance_matrix(
    covariance_matrix,
    mean,
    labels=None,
    figsize=(10, 10),
    reference_values=None,
    ellipse_colors="g",
    showticks=True,
    **kwargs,
):
    """Draw a corner plot of the Gaussian described by ``mean`` and ``covariance_matrix``.

    The figure is an ``N x N`` grid of axes, ``N`` being the number of
    parameters (``covariance_matrix.shape[0]``):

    * diagonal panels show the 1D Gaussian pdf of parameter ``i`` over
      ``mean[i] +/- 3 sigma_i``;
    * panels below the diagonal (row ``i``, column ``j < i``) show the
      1-sigma and 2-sigma ellipses of the 2D marginal of parameters
      ``(j, i)``, with parameter ``j`` on x and ``i`` on y;
    * panels above the diagonal are switched off.

    Parameters
    ----------
    covariance_matrix : array_like, shape (N, N)
        Covariance of the parameters. Assumed symmetric; only the 2x2
        sub-blocks and the diagonal are used.
    mean : array_like, shape (N,)
        Centre of the distribution.
    labels : sequence of str, optional
        Axis label per parameter, drawn on the bottom row (x) and the left
        column (y).
    figsize : tuple of float, optional
        Passed to :func:`matplotlib.pyplot.subplots`.
    reference_values : sequence of float, optional
        Values marked with red lines in every visible panel (vertical line
        for the x parameter, horizontal line for the y parameter).
    ellipse_colors : matplotlib color, optional
        Edge colour of the covariance ellipses.
    showticks : bool, optional
        If False, remove the ticks on the edge axes that carry tick labels
        (x on the bottom row, y on the left column).
    **kwargs
        Accepted for backward compatibility; currently ignored.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure holding the grid.
    ax : matplotlib.axes.Axes
        The last axes drawn (the bottom-right diagonal panel), not the grid.
    """
    cov = np.asarray(covariance_matrix, dtype=float)
    mean = np.asarray(mean, dtype=float)
    num_params = cov.shape[0]
    sigma = np.sqrt(np.diag(cov))

    # squeeze=False keeps ``axes`` 2D even for a single parameter; without it
    # plt.subplots(1, 1) returns a bare Axes and ``axes[i, j]`` raises.
    fig, axes = plt.subplots(num_params, num_params, figsize=figsize, squeeze=False)
    plt.subplots_adjust(wspace=0.0, hspace=0.0)

    for i in range(num_params):
        ref_i = None if reference_values is None else reference_values[i]
        for j in range(num_params):
            ax = axes[i, j]
            if i == j:
                _plot_marginal(ax, mean[i], sigma[i], ref_i)
            elif j < i:
                ref_j = None if reference_values is None else reference_values[j]
                _plot_pair(
                    ax,
                    cov[np.ix_([j, i], [j, i])],
                    (mean[j], mean[i]),
                    (sigma[j], sigma[i]),
                    ellipse_colors,
                    (ref_j, ref_i),
                )
            else:
                ax.axis("off")
            _format_axes(ax, i, j, num_params, labels, showticks)

    return fig, ax


def _plot_marginal(ax, mu, sigma, reference_value=None):
    """Plot the pdf of N(mu, sigma**2) on ``ax`` over mu +/- 3 sigma."""
    lo, hi = mu - 3 * sigma, mu + 3 * sigma
    x = np.linspace(lo, hi, 100)
    ax.plot(x, norm.pdf(x, mu, sigma), color="g")
    ax.set_xlim(lo, hi)
    if reference_value is not None:
        ax.axvline(reference_value, color="red", linestyle="-", lw=1)


def _plot_pair(ax, cov2, center, sigmas, color, reference=(None, None), margin=3):
    """Draw 1- and 2-sigma ellipses of the 2x2 covariance ``cov2`` on ``ax``.

    ``center`` and ``sigmas`` are ``(x, y)`` pairs; ``reference`` holds an
    optional reference value for each axis.
    """
    # cov2 is a symmetric covariance block: eigh gives real eigenvalues and
    # orthonormal eigenvectors (eig may return complex/unordered values).
    eigvals, eigvecs = np.linalg.eigh(cov2)
    # Clip round-off negatives so the square root never yields NaN.
    half_axes = np.sqrt(np.clip(eigvals, 0.0, None))
    # Orientation of the first eigenvector, which the ellipse width follows.
    angle = np.rad2deg(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))
    for k in (1, 2):
        ax.add_patch(
            Ellipse(
                xy=center,
                width=2 * k * half_axes[0],
                height=2 * k * half_axes[1],
                angle=angle,
                edgecolor=color,
                facecolor="none",
            )
        )

    ax.set_xlim(center[0] - margin * sigmas[0], center[0] + margin * sigmas[0])
    ax.set_ylim(center[1] - margin * sigmas[1], center[1] + margin * sigmas[1])

    ref_x, ref_y = reference
    if ref_x is not None:
        ax.axvline(ref_x, color="red", linestyle="-", lw=1)
    if ref_y is not None:
        ax.axhline(ref_y, color="red", linestyle="-", lw=1)


def _format_axes(ax, i, j, num_params, labels, showticks):
    """Show tick labels and axis labels only on the bottom row and left column."""
    if i < num_params - 1:
        ax.set_xticklabels([])
    else:
        if labels is not None:
            ax.set_xlabel(labels[j])
        if not showticks:
            # Bottom row carries the x tick labels, so it is the x axis
            # whose ticks are removed.
            ax.xaxis.set_major_locator(plt.NullLocator())
    if j > 0:
        ax.set_yticklabels([])
    else:
        if labels is not None:
            ax.set_ylabel(labels[i])
        if not showticks:
            # Left column carries the y tick labels.
            ax.yaxis.set_major_locator(plt.NullLocator())
if __name__ == "__main__":
    # Visual smoke test: two zero-centred parameters, each with variance 4
    # and covariance -2 (correlation coefficient -0.5).
    demo_covariance = np.array([[4, -2], [-2, 4]])
    demo_mean = np.array([0, 0])
    fig, ax = covariance_matrix(demo_covariance, demo_mean)
    plt.show()