import numpy as np
import torch
from astropy.visualization import HistEqStretch, ImageNormalize
from matplotlib.patches import Polygon
import matplotlib
from scipy.stats import iqr
from ..models import Group_Model, PSF_Model
from ..image import Image_List, Window_List
from .. import AP_config
from ..utils.conversions.units import flux_to_sb
from .visuals import *
# Public plotting helpers exported by a star-import of this module.
__all__ = ["target_image", "psf_image", "model_image", "residual_image", "model_window"]
[docs]
def target_image(fig, ax, target, window=None, **kwargs):
    """
    Draw a target image on the given axes.

    By default the image is drawn in two layers:
    - Pixels up to three robust sigma above the median sky level are shown in
      greyscale with a histogram-equalisation stretch.
    - Brighter pixels are overlaid with a logarithmic colour map.

    Args:
        fig (matplotlib.figure.Figure): Figure that owns ``ax``.
        ax (matplotlib.axes.Axes): Axes to draw on. When ``target`` is an
            ``Image_List`` this must be indexable, one axes per sub-image.
        target (Image or Image_List): Image(s) to display.
        window (Window, optional): Region of ``target`` to display. When ``None``
            the target's own window is used.
        **kwargs: Only two flags are read. ``flipx`` inverts the x axis.
            ``linear`` draws the data with a single linear colour map instead of
            the two-layer stretch. Any other keys are ignored.

    Returns:
        fig (matplotlib.figure.Figure): The figure passed in.
        ax (matplotlib.axes.Axes): The axes passed in.

    Note:
        An ``Image_List`` is handled by calling this function once per sub-image,
        forwarding ``window`` and ``kwargs`` unchanged.
    """
    # Image lists: one panel per sub-image, same options for each.
    if isinstance(target, Image_List):
        for index, sub_image in enumerate(target.image_list):
            target_image(fig, ax[index], sub_image, window=window, **kwargs)
        return fig, ax

    window = target.window if window is None else window
    if kwargs.get("flipx", False):
        ax.invert_xaxis()

    cutout = target[window]
    pixels = np.copy(cutout.data.detach().cpu().numpy())
    if cutout.has_mask:
        # Masked pixels become NaN so the nan-aware statistics and pcolormesh skip them.
        pixels[cutout.mask.detach().cpu().numpy()] = np.nan

    x_corners, y_corners = cutout.get_coordinate_corner_meshgrid()
    x_corners = x_corners.detach().cpu().numpy()
    y_corners = y_corners.detach().cpu().numpy()

    finite = np.isfinite(pixels)
    sky = np.nanmedian(pixels)
    # Half of the 16th-84th percentile range is a robust one-sigma estimate.
    noise = iqr(pixels[finite], rng=(16, 84)) / 2
    if noise == 0:
        noise = np.nanstd(pixels)

    if kwargs.get("linear", False):
        ax.pcolormesh(x_corners, y_corners, pixels, cmap=cmap_grad)
    else:
        threshold = sky + 3 * noise
        # Faint layer: histogram-equalised greyscale up to the threshold.
        faint_norm = ImageNormalize(
            stretch=HistEqStretch(pixels[np.logical_and(pixels <= threshold, finite)]),
            clip=False,
            vmax=threshold,
            vmin=np.nanmin(pixels),
        )
        ax.pcolormesh(x_corners, y_corners, pixels, cmap="Greys", norm=faint_norm)
        # Bright layer: log-scaled colour for pixels at or above the threshold.
        ax.pcolormesh(
            x_corners,
            y_corners,
            np.ma.masked_where(pixels < threshold, pixels),
            cmap=cmap_grad,
            norm=matplotlib.colors.LogNorm(),
            clim=[threshold, None],
        )

    ax.axis("equal")
    ax.set_xlabel("Tangent Plane X [arcsec]")
    ax.set_ylabel("Tangent Plane Y [arcsec]")
    return fig, ax
[docs]
@torch.no_grad()
def psf_image(
    fig,
    ax,
    psf,
    window=None,
    cmap_levels=None,
    flipx=False,
    **kwargs,
):
    """
    Display a PSF image on the given axes with a logarithmic colour scale.

    Args:
        fig (matplotlib.figure.Figure): Figure that owns ``ax``.
        ax (matplotlib.axes.Axes): Axes to draw on. When the PSF is an
            ``Image_List`` this must be indexable, one axes per sub-image.
        psf (PSF_Model, Image or Image_List): The PSF to show. A ``PSF_Model``
            is first evaluated by calling it with no arguments.
        window (Window, optional): Region to display. When ``None`` the PSF
            image's own window is used.
        cmap_levels (int, optional): If given, the colour map is quantised into
            this many discrete colours. Defaults to ``None``.
        flipx (bool, optional): Invert the x axis. Defaults to ``False``.
        **kwargs: Passed to ``ax.pcolormesh``, overriding the default ``cmap``
            and ``norm``.

    Returns:
        fig (matplotlib.figure.Figure): The figure passed in.
        ax (matplotlib.axes.Axes): The axes passed in.
    """
    if isinstance(psf, PSF_Model):
        psf = psf()
    # Image lists: draw each sub-image on its own axes. Every display option is
    # forwarded so all panels are drawn consistently.
    if isinstance(psf, Image_List):
        for i, sub_psf in enumerate(psf.image_list):
            psf_image(
                fig,
                ax[i],
                sub_psf,
                window=window,
                cmap_levels=cmap_levels,
                flipx=flipx,
                **kwargs,
            )
        return fig, ax
    if window is None:
        window = psf.window
    if flipx:
        ax.invert_xaxis()
    # Cut out the requested window.
    psf_cut = psf[window]
    X, Y = psf_cut.get_coordinate_corner_meshgrid()
    X = X.detach().cpu().numpy()
    Y = Y.detach().cpu().numpy()
    psf_data = psf_cut.data.detach().cpu().numpy()
    # Default drawing options; user kwargs take precedence.
    imshow_kwargs = {
        "cmap": cmap_grad,
        "norm": matplotlib.colors.LogNorm(),
    }
    imshow_kwargs.update(kwargs)
    # If requested, convert the continuous colour map into discrete levels.
    if cmap_levels is not None:
        imshow_kwargs["cmap"] = matplotlib.colors.ListedColormap(
            list(imshow_kwargs["cmap"](c) for c in np.linspace(0.0, 1.0, cmap_levels))
        )
    ax.pcolormesh(X, Y, psf_data, **imshow_kwargs)
    # Enforce equal spacing on x and y.
    ax.axis("equal")
    ax.set_xlabel("PSF X [arcsec]")
    ax.set_ylabel("PSF Y [arcsec]")
    return fig, ax
[docs]
@torch.no_grad()
def model_image(
    fig,
    ax,
    model,
    sample_image=None,
    window=None,
    target=None,
    showcbar=True,
    target_mask=False,
    cmap_levels=None,
    flipx=False,
    magunits=True,
    sample_full_image=False,
    **kwargs,
):
    """
    Generate a model image (if needed) and display it on the given axes.

    Args:
        fig (matplotlib.figure.Figure): The figure object in which the image will be displayed.
        ax (matplotlib.axes.Axes): The axes object on which the image will be plotted.
        model (Model): Used to generate a model image if `sample_image` is not provided.
        sample_image (Image or Image_List, optional): The image or list of images to be displayed.
            If `None`, a model image is generated using the provided `model`. Defaults to `None`.
        window (Window, optional): The window through which the image is viewed. If `None`, the window of the
            provided `model` is used. Defaults to `None`.
        target (Target, optional): The target or list of targets for the image or image list.
            If `None`, the target of the `model` is used. Defaults to `None`.
        showcbar (bool, optional): Whether to show the color bar. Defaults to `True`.
        target_mask (bool, optional): Whether to apply the mask of the target. If `True` and the target has a
            mask, the portion of that mask inside `window` is applied to the image. Defaults to `False`.
        cmap_levels (int, optional): The number of discrete levels to convert the continuous color map to.
            If not `None`, the color map is converted to a ListedColormap with the specified number of levels.
            Defaults to `None`.
        flipx (bool, optional): Invert the x axis. Defaults to `False`.
        magunits (bool, optional): If `True` and the target has a zeropoint, the image is converted to
            surface brightness with `flux_to_sb`, the log norm is dropped and the colour map is reversed.
            Defaults to `True`.
        sample_full_image: If True, every model will be sampled on the full image window.
            If False (default) each model will only be sampled in its fitting window.
        **kwargs: Arbitrary keyword arguments. These are used to override the default imshow_kwargs.

    Returns:
        fig (matplotlib.figure.Figure): The figure object containing the displayed image.
        ax (matplotlib.axes.Axes): The axes object containing the displayed image.

    Note:
        If the `sample_image` is an `Image_List`, this function calls itself recursively for each image
        in the list, with the corresponding target and window. The `showcbar` parameter and `kwargs`
        are passed unchanged to each recursive call.
    """
    if sample_image is None:
        if sample_full_image:
            sample_image = model.make_model_image()
            sample_image = model(sample_image)
        else:
            sample_image = model()
    # Use model target if not given.
    if target is None:
        target = model.target
    # Use model window if not given.
    if window is None:
        window = model.window
    # Handle image lists: one panel per (image, target, window) triple.
    if isinstance(sample_image, Image_List):
        for i, (sub_image, sub_target, sub_window) in enumerate(zip(sample_image, target, window)):
            model_image(
                fig,
                ax[i],
                model,
                sample_image=sub_image,
                window=sub_window,
                target=sub_target,
                showcbar=showcbar,
                target_mask=target_mask,
                cmap_levels=cmap_levels,
                flipx=flipx,
                magunits=magunits,
                **kwargs,
            )
        return fig, ax
    if flipx:
        ax.invert_xaxis()
    # Cut out the requested window.
    sample_image = sample_image[window]
    X, Y = sample_image.get_coordinate_corner_meshgrid()
    X = X.detach().cpu().numpy()
    Y = Y.detach().cpu().numpy()
    sample_image = sample_image.data.detach().cpu().numpy()
    # Default drawing options; user kwargs take precedence.
    imshow_kwargs = {
        "cmap": cmap_grad,
        "norm": matplotlib.colors.LogNorm(),
    }
    imshow_kwargs.update(kwargs)
    # If requested, convert the continuous colour map into discrete levels.
    if cmap_levels is not None:
        imshow_kwargs["cmap"] = matplotlib.colors.ListedColormap(
            list(imshow_kwargs["cmap"](c) for c in np.linspace(0.0, 1.0, cmap_levels))
        )
    # If a zeropoint is available, convert to surface brightness units.
    if target.zeropoint is not None and magunits:
        sample_image = flux_to_sb(sample_image, target.pixel_area.item(), target.zeropoint.item())
        imshow_kwargs.pop("norm", None)
        # Brighter objects have smaller magnitudes, so flip the colour map.
        imshow_kwargs["cmap"] = imshow_kwargs["cmap"].reversed()
    # Apply the mask if available. The image has already been cut to `window`,
    # so the mask must be cut to the same window for the shapes to agree.
    if target_mask and target.has_mask:
        sample_image[target[window].mask.detach().cpu().numpy()] = np.nan
    im = ax.pcolormesh(X, Y, sample_image, **imshow_kwargs)
    # Enforce equal spacing on x and y.
    ax.axis("equal")
    ax.set_xlabel("Tangent Plane X [arcsec]")
    ax.set_ylabel("Tangent Plane Y [arcsec]")
    if showcbar:
        if target.zeropoint is not None and magunits:
            clb = fig.colorbar(im, ax=ax, label="Surface Brightness [mag/arcsec$^2$]")
            # Put bright (small-magnitude) values at the top of the colour bar.
            clb.ax.invert_yaxis()
        else:
            fig.colorbar(im, ax=ax, label="log$_{10}$(flux)")
    return fig, ax
[docs]
@torch.no_grad()
def residual_image(
    fig,
    ax,
    model,
    target=None,
    sample_image=None,
    showcbar=True,
    window=None,
    center_residuals=False,
    clb_label=None,
    normalize_residuals=False,
    flipx=False,
    sample_full_image=False,
    **kwargs,
):
    """
    Calculate and display the residuals of a model image with respect to a target image.

    The residuals are the difference between the target image and the sample image. Before
    plotting they are compressed with ``arctan(r / s)``, where ``s`` is twice the 10th-90th
    percentile range of the finite residuals. If that range is zero or undefined, ``s = 1``.

    Args:
        fig (matplotlib.figure.Figure): The figure object in which the residuals will be displayed.
        ax (matplotlib.axes.Axes): The axes object on which the residuals will be plotted.
        model (Model): Used to generate a model image if `sample_image` is not provided.
        target (Target or Image_List, optional): The target or list of targets for the image or image list.
            If `None`, the target of the `model` is used. Defaults to `None`.
        sample_image (Image or Image_List, optional): The image or list of images from which residuals
            will be calculated. If `None`, a model image is generated using the provided `model`.
            Defaults to `None`.
        showcbar (bool, optional): Whether to show the color bar. Defaults to `True`.
        window (Window or Window_List, optional): The window through which the image is viewed. If `None`,
            the window of the provided `model` is used. Defaults to `None`.
        center_residuals (bool, optional): If `True`, the median of the residuals is subtracted.
            Defaults to `False`.
        clb_label (str, optional): The label for the colorbar. If `None`, a default label is used based
            on the normalization of the residuals. Defaults to `None`.
        normalize_residuals (bool or torch.Tensor, optional): If `True`, residuals are divided by the
            square root of the target variance. If a tensor, residuals are divided by its square root.
            Defaults to `False`.
        flipx (bool, optional): Invert the x axis. Defaults to `False`.
        sample_full_image: If True, every model will be sampled on the full image window.
            If False (default) each model will only be sampled in its fitting window.
        **kwargs: Arbitrary keyword arguments. These are used to override the default imshow_kwargs.

    Returns:
        fig (matplotlib.figure.Figure): The figure object containing the displayed residuals.
        ax (matplotlib.axes.Axes): The axes object containing the displayed residuals.

    Note:
        If the `window` is a `Window_List` or the `target` is an `Image_List`, this function calls
        itself recursively for each element, with the corresponding window, target, and sample image.
    """
    if window is None:
        window = model.window
    if target is None:
        target = model.target
    if sample_image is None:
        if sample_full_image:
            sample_image = model.make_model_image()
            sample_image = model(sample_image)
        else:
            sample_image = model()
    if isinstance(window, Window_List) or isinstance(target, Image_List):
        for i_ax, win, tar, sam in zip(ax, window, target, sample_image):
            residual_image(
                fig,
                i_ax,
                model,
                target=tar,
                sample_image=sam,
                window=win,
                showcbar=showcbar,
                center_residuals=center_residuals,
                clb_label=clb_label,
                normalize_residuals=normalize_residuals,
                flipx=flipx,
                **kwargs,
            )
        return fig, ax
    if flipx:
        ax.invert_xaxis()
    X, Y = sample_image[window].get_coordinate_corner_meshgrid()
    X = X.detach().cpu().numpy()
    Y = Y.detach().cpu().numpy()
    residuals = (target[window] - sample_image[window]).data
    if isinstance(normalize_residuals, bool) and normalize_residuals:
        residuals = residuals / torch.sqrt(target[window].variance)
    elif isinstance(normalize_residuals, torch.Tensor):
        residuals = residuals / torch.sqrt(normalize_residuals)
        normalize_residuals = True
    residuals = residuals.detach().cpu().numpy()
    if target.has_mask:
        residuals[target[window].mask.detach().cpu().numpy()] = np.nan
    if center_residuals:
        residuals -= np.nanmedian(residuals)
    # Compress the dynamic range with arctan, scaled by the 10-90 percentile spread.
    finite = residuals[np.isfinite(residuals)]
    scale = iqr(finite, rng=[10, 90]) * 2 if finite.size > 0 else np.nan
    if not np.isfinite(scale) or scale <= 0:
        # A zero spread (e.g. the model equals the target) or no finite pixels would
        # otherwise turn residuals into NaN/inf; fall back to unit scaling instead.
        scale = 1.0
    residuals = np.arctan(residuals / scale)
    finite = residuals[np.isfinite(residuals)]
    # arctan is bounded by pi/2, so that is a safe colour limit when nothing finite remains.
    extreme = np.max(np.abs(finite)) if finite.size > 0 else np.pi / 2
    imshow_kwargs = {
        "cmap": cmap_div,
        "vmin": -extreme,
        "vmax": extreme,
    }
    imshow_kwargs.update(kwargs)
    im = ax.pcolormesh(X, Y, residuals, **imshow_kwargs)
    ax.axis("equal")
    ax.set_xlabel("Tangent Plane X [arcsec]")
    ax.set_ylabel("Tangent Plane Y [arcsec]")
    if showcbar:
        if normalize_residuals:
            default_label = f"tan$^{{-1}}$((Target - {model.name}) / $\\sigma$)"
        else:
            default_label = f"tan$^{{-1}}$(Target - {model.name})"
        clb = fig.colorbar(im, ax=ax, label=default_label if clb_label is None else clb_label)
        # The arctan scale is non-linear, so numeric ticks would be misleading.
        clb.ax.set_yticks([])
        clb.ax.set_yticklabels([])
    return fig, ax
[docs]
def _model_window_polygon(model, target, rectangle_linewidth):
    """Build an unfilled Polygon tracing the window that ``model`` uses on ``target``."""
    # Multi-image models keep one window per target image; pick the matching one.
    if isinstance(model.window, Window_List):
        use_window = model.window.window_list[model.target.index(target)]
    else:
        use_window = model.window

    def corner(zeroed_axis):
        # Plane position reached from the origin by walking the full pixel shape
        # along one axis only (the other pixel component is zeroed).
        delta = use_window.pixel_shape.clone().to(dtype=AP_config.ap_dtype)
        delta[zeroed_axis] = 0.0
        return (use_window.origin + use_window.pixel_to_plane_delta(delta)).detach().cpu().numpy()

    origin = use_window.origin.detach().cpu().numpy()
    lowright = corner(1)
    upleft = corner(0)
    # NOTE(review): assumes `use_window.end` is the plane-space offset of the far
    # corner relative to the origin — confirm against the Window class.
    end = (use_window.origin + use_window.end).detach().cpu().numpy()
    return Polygon(
        xy=[
            (origin[0], origin[1]),
            (lowright[0], lowright[1]),
            (end[0], end[1]),
            (upleft[0], upleft[1]),
        ],
        fill=False,
        linewidth=rectangle_linewidth,
        edgecolor=main_pallet["secondary1"],
    )


def model_window(fig, ax, model, target=None, rectangle_linewidth=2, **kwargs):
    """
    Outline the window(s) of a model on the given axes.

    Args:
        fig (matplotlib.figure.Figure): Figure that owns ``ax``.
        ax (matplotlib.axes.Axes or numpy.ndarray): Axes to draw on. For an array
            of axes, entry ``i`` is paired with ``model.target.image_list[i]``.
        model (Model): Model whose window is drawn. For a ``Group_Model`` the
            window of every sub-model is drawn.
        target (Image, optional): Target image used to select the window when a
            model holds a ``Window_List``.
        rectangle_linewidth (float, optional): Line width of the outline.
            Defaults to 2.
        **kwargs: Forwarded to recursive calls; not otherwise used.

    Returns:
        fig (matplotlib.figure.Figure): The figure passed in.
        ax: The axes passed in.
    """
    if isinstance(ax, np.ndarray):
        # Forward the line width too, so every panel uses the requested style.
        for i, axitem in enumerate(ax):
            model_window(
                fig,
                axitem,
                model,
                target=model.target.image_list[i],
                rectangle_linewidth=rectangle_linewidth,
                **kwargs,
            )
        return fig, ax
    sub_models = model.models.values() if isinstance(model, Group_Model) else [model]
    for sub_model in sub_models:
        ax.add_patch(_model_window_polygon(sub_model, target, rectangle_linewidth))
    return fig, ax