import numpy as np
from matplotlib.patches import Ellipse
from matplotlib import pyplot as plt
from scipy.stats import norm
from .visuals import main_pallet
__all__ = ("covariance_matrix",)
# Half-width of every panel's axis range, in standard deviations of that parameter.
_AXIS_SPAN_SIGMAS = 3
# Multiples of the principal-axis standard deviations at which ellipses are drawn.
_ELLIPSE_SIGMAS = (1, 2)


def _plot_marginal(ax, mu, sigma, color, reference_value=None):
    """Draw the 1-D Gaussian density N(mu, sigma**2) on a diagonal panel."""
    lo = mu - _AXIS_SPAN_SIGMAS * sigma
    hi = mu + _AXIS_SPAN_SIGMAS * sigma
    x = np.linspace(lo, hi, 100)
    ax.plot(x, norm.pdf(x, mu, sigma), color=color, lw=1.5)
    ax.set_xlim(lo, hi)
    if reference_value is not None:
        ax.axvline(reference_value, color=main_pallet["pop"], linestyle="-", lw=1)


def _plot_pair(ax, cov_pair, center, sigmas_xy, color, reference_xy=None):
    """Draw covariance ellipses for one (x, y) parameter pair on a lower-triangle panel."""
    # Symmetrize so tiny round-off asymmetries cannot bias the symmetric solver.
    cov_pair = 0.5 * (cov_pair + cov_pair.T)
    # eigh is meant for symmetric matrices: real eigenvalues, orthonormal eigenvectors.
    eigvals, eigvecs = np.linalg.eigh(cov_pair)
    # Round-off can push a PSD matrix's smallest eigenvalue slightly below zero;
    # clip it so the square root stays real instead of producing NaN.
    half_axes = np.sqrt(np.clip(eigvals, 0.0, None))
    # Ellipse "width" lies along the first eigenvector; rotate it to that direction.
    angle = np.degrees(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))
    for k in _ELLIPSE_SIGMAS:
        ax.add_patch(
            Ellipse(
                xy=center,
                width=2 * k * half_axes[0],
                height=2 * k * half_axes[1],
                angle=angle,
                edgecolor=color,
                facecolor="none",
                lw=1.5,
            )
        )
    (cx, cy), (sx, sy) = center, sigmas_xy
    ax.set_xlim(cx - _AXIS_SPAN_SIGMAS * sx, cx + _AXIS_SPAN_SIGMAS * sx)
    ax.set_ylim(cy - _AXIS_SPAN_SIGMAS * sy, cy + _AXIS_SPAN_SIGMAS * sy)
    if reference_xy is not None:
        ax.axvline(reference_xy[0], color=main_pallet["pop"], linestyle="-", lw=1)
        ax.axhline(reference_xy[1], color=main_pallet["pop"], linestyle="-", lw=1)


def _format_panel(ax, row, col, num_params, labels, showticks):
    """Put axis labels on the outer edges of the grid and hide interior tick labels."""
    if row == num_params - 1:
        if labels is not None:
            ax.set_xlabel(labels[col])
    else:
        ax.tick_params(axis="x", labelbottom=False)
    if col == 0:
        if labels is not None:
            ax.set_ylabel(labels[row])
    else:
        ax.tick_params(axis="y", labelleft=False)
    if not showticks:
        # Hide ticks on every panel, on both axes.
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())


def covariance_matrix(
    covariance_matrix,
    mean,
    labels=None,
    figsize=(10, 10),
    reference_values=None,
    ellipse_colors=main_pallet["primary1"],
    showticks=True,
    **kwargs,
):
    """
    Create a covariance matrix plot. Creates a corner plot with ellipses representing the covariance between parameters.

    Diagonal panels show the 1-D Gaussian density of each parameter. Lower-triangle
    panels show ellipses at 1 and 2 standard deviations along the principal axes of
    the corresponding 2x2 sub-covariance. Upper-triangle panels are switched off.
    Every panel spans +/- 3 standard deviations around the mean.

    **Args:**
    - `covariance_matrix` (array-like): Covariance matrix of shape (n_params, n_params).
    - `mean` (array-like): Mean values of the parameters, shape (n_params,).
    - `labels` (list, optional): Labels for the parameters.
    - `figsize` (tuple, optional): Size of the figure. Default is (10, 10).
    - `reference_values` (array-like, optional): Reference values for the parameters, used to draw vertical and horizontal lines. Typically these are the true values of the parameters.
    - `ellipse_colors` (matplotlib color, optional): A single color used for the densities and ellipses. Default is `main_pallet["primary1"]`.
    - `showticks` (bool, optional): Whether to show ticks on the axes. Default is True.
    - `**kwargs`: Accepted for compatibility; currently ignored.

    **Returns:**
    `(fig, ax)`: the Figure, and the Axes of the bottom-right (last diagonal) panel.
    All panels are available through `fig.axes` for further customization.

    **Raises:**
    `ValueError` if `covariance_matrix` is not a non-empty square 2-D array, or if
    `mean` does not have one entry per parameter.
    """
    cov = np.asarray(covariance_matrix, dtype=float)
    if cov.ndim != 2 or cov.shape[0] != cov.shape[1] or cov.shape[0] == 0:
        raise ValueError(
            f"covariance_matrix must be a non-empty square 2-D array, got shape {cov.shape}"
        )
    num_params = cov.shape[0]
    mu = np.asarray(mean, dtype=float).reshape(-1)
    if mu.shape[0] != num_params:
        raise ValueError(
            f"mean must have {num_params} entries to match covariance_matrix, got {mu.shape[0]}"
        )
    sigmas = np.sqrt(np.diag(cov))

    # squeeze=False keeps `axes` 2-D even when num_params == 1.
    fig, axes = plt.subplots(num_params, num_params, figsize=figsize, squeeze=False)
    fig.subplots_adjust(wspace=0.0, hspace=0.0)

    for i in range(num_params):
        for j in range(num_params):
            ax = axes[i, j]
            if i == j:
                _plot_marginal(
                    ax,
                    mu[i],
                    sigmas[i],
                    ellipse_colors,
                    None if reference_values is None else reference_values[i],
                )
            elif j < i:
                # Panel (i, j): parameter j on x, parameter i on y.
                _plot_pair(
                    ax,
                    cov[np.ix_([j, i], [j, i])],
                    (mu[j], mu[i]),
                    (sigmas[j], sigmas[i]),
                    ellipse_colors,
                    None
                    if reference_values is None
                    else (reference_values[j], reference_values[i]),
                )
            else:
                ax.axis("off")
            _format_panel(ax, i, j, num_params, labels, showticks)

    # Historically this returned the last panel visited; keep that for callers.
    return fig, axes[-1, -1]
if __name__ == "__main__":
    # Quick visual demo: two negatively correlated parameters centred at the origin.
    # NOTE: this module uses a relative import (`from .visuals import ...`), so it
    # must be run as part of its package (e.g. `python -m <package>.<module>`).
    demo_covariance = np.array([[4, -2], [-2, 4]])
    demo_mean = np.array([0, 0])
    fig, ax = covariance_matrix(demo_covariance, demo_mean)
    plt.show()