Source code for astrophot.plots.image

from typing import Literal, Optional, Union
import numpy as np
import torch

from astropy.visualization import HistEqStretch, ImageNormalize
from matplotlib.patches import Polygon
import matplotlib
from scipy.stats import iqr

from ..models import GroupModel, PSFModel, PSFGroupModel
from ..image import ImageList, WindowList, PSFImage
from .. import config
from ..backend_obj import backend
from ..utils.conversions.units import flux_to_sb
from ..utils.decorators import ignore_numpy_warnings
from .visuals import *

# Public plotting helpers re-exported by ``from <this module> import *``.
__all__ = (
    "target_image",
    "psf_image",
    "model_image",
    "residual_image",
    "model_window",
)


@ignore_numpy_warnings
def target_image(fig, ax, target, window=None, **kwargs):
    """
    Display a target image on the provided axes.

    Pixels up to ``sky + 3 * noise`` are drawn in grayscale using a histogram
    equalization stretch, which brings out faint structure. If more than five
    pixels lie at or above that threshold, they are overlaid with a
    log-normalized colour map. Passing ``linear=True`` instead draws the whole
    image with a single linear colour map.

    ``sky`` is the NaN-ignoring median of the (masked) pixel values, and
    ``noise`` is half the 16th-84th percentile range of the finite pixels. If
    that range is zero, ``noise`` falls back to their standard deviation.

    :param fig: The figure object in which the target image will be displayed.
    :type fig: matplotlib.figure.Figure
    :param ax: The axes object on which the target image will be plotted. For an
        image list, an indexable collection of axes with one entry per image.
    :type ax: matplotlib.axes.Axes
    :param target: The image or list of images to be displayed.
    :type target: Image or Image_List
    :param window: The window through which the image is viewed. If `None`, the
        window of the provided `target` is used. Defaults to `None`.
    :type window: Window, optional
    :param kwargs: Only ``linear`` (bool, default False) is read here. All
        kwargs are forwarded unchanged to recursive calls for image lists.
    :returns: The ``(fig, ax)`` pair that was passed in.

    Note:
        If the `target` is an `Image_List`, this function calls itself for each
        image in the list. The `window` parameter and `kwargs` are passed
        unchanged to each recursive call.
    """
    # Image lists: draw each member on its own axes.
    if isinstance(target, ImageList):
        for idx, member in enumerate(target.images):
            target_image(fig, ax[idx], member, window=window, **kwargs)
        return fig, ax

    if window is None:
        window = target.window
    area = target[window]

    # Copy so that masking with NaN never writes back into the image storage.
    pixels = np.copy(backend.to_numpy(area._data))
    pixels[backend.to_numpy(area._mask)] = np.nan

    X, Y = area.coordinate_corner_meshgrid()
    X = backend.to_numpy(X)
    Y = backend.to_numpy(Y)

    finite = np.isfinite(pixels)
    sky = np.nanmedian(pixels)
    noise = iqr(pixels[finite], rng=(16, 84)) / 2
    if noise == 0:
        noise = np.nanstd(pixels)
    threshold = sky + 3 * noise

    if kwargs.get("linear", False):
        ax.pcolormesh(X, Y, pixels, cmap=cmap_grad)
    else:
        ax.pcolormesh(
            X,
            Y,
            pixels,
            cmap="gray_r",
            norm=ImageNormalize(
                stretch=HistEqStretch(pixels[(pixels <= threshold) & finite]),
                clip=False,
                vmax=threshold,
                vmin=np.nanmin(pixels),
            ),
        )
        # NaN (masked) pixels compare False here, so they neither count as
        # "bright" nor are drawn in the log overlay.
        bright = pixels >= threshold
        # Only draw the log overlay if several pixels are above the noise.
        if np.count_nonzero(bright) > 5:
            ax.pcolormesh(
                X,
                Y,
                np.ma.masked_where(~bright, pixels),
                cmap=cmap_grad,
                norm=matplotlib.colors.LogNorm(),
                clim=[threshold, None],
            )

    # A negative CD determinant means a flipped x axis on the sky.
    if np.linalg.det(target.CD.npvalue) < 0:
        ax.invert_xaxis()
    ax.axis("equal")
    ax.set_xlabel("Tangent Plane X [arcsec]")
    ax.set_ylabel("Tangent Plane Y [arcsec]")
    return fig, ax
@torch.no_grad()
@ignore_numpy_warnings
def psf_image(
    fig,
    ax,
    psf: Union[PSFImage, PSFModel, PSFGroupModel],
    cmap_levels: Optional[int] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    **kwargs,
):
    """For plotting PSF images, or the output of a PSF model.

    PSF models are evaluated first. The resulting image is drawn in pixel
    coordinates with a logarithmic colour normalization.

    :param fig: The figure object in which the PSF image will be displayed.
    :type fig: matplotlib.figure.Figure
    :param ax: The axes object on which the PSF image will be plotted. For an
        image list, an indexable collection of axes with one entry per image.
    :type ax: matplotlib.axes.Axes
    :param psf: The PSF image, model, or group model to be displayed.
    :type psf: PSFImage or PSFModel or PSFGroupModel
    :param cmap_levels: The number of discrete levels to convert the continuous
        color map to. If not `None`, the color map is converted to a
        ListedColormap with the specified number of levels. Defaults to `None`.
    :type cmap_levels: int, optional
    :param vmin: The minimum value for the color scale. Defaults to `None`.
    :type vmin: float, optional
    :param vmax: The maximum value for the color scale. Defaults to `None`.
    :type vmax: float, optional
    :param kwargs: Extra keyword arguments passed to ``ax.pcolormesh``. These
        override the defaults (``cmap`` and a ``LogNorm`` ``norm``). ``cmap``
        may be a Colormap or a registered colour map name.
    :returns: The ``(fig, ax)`` pair that was passed in.
    """
    if isinstance(psf, (PSFModel, PSFGroupModel)):
        psf = psf()

    # Image lists: recurse with all display options, not just kwargs.
    if isinstance(psf, ImageList):
        for idx, member in enumerate(psf.images):
            psf_image(
                fig,
                ax[idx],
                member,
                cmap_levels=cmap_levels,
                vmin=vmin,
                vmax=vmax,
                **kwargs,
            )
        return fig, ax

    pix_i, pix_j = psf.pixel_corner_meshgrid()
    pix_i = backend.to_numpy(pix_i)
    pix_j = backend.to_numpy(pix_j)
    data = backend.to_numpy(psf._data)

    plot_kwargs = {
        "cmap": cmap_grad,
        "norm": matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax),
        **kwargs,
    }

    # If requested, convert the continuous colour map into discrete levels.
    if cmap_levels is not None:
        cmap = plot_kwargs["cmap"]
        if isinstance(cmap, str):
            # Names must be resolved before the colour map can be sampled.
            cmap = matplotlib.colormaps[cmap]
        plot_kwargs["cmap"] = matplotlib.colors.ListedColormap(
            [cmap(c) for c in np.linspace(0.0, 1.0, cmap_levels)]
        )

    ax.pcolormesh(pix_i, pix_j, data, **plot_kwargs)

    # Enforce equal spacing on x and y.
    ax.axis("equal")
    ax.set_xlabel("PSF I [pix]")
    ax.set_ylabel("PSF J [pix]")
    return fig, ax
@torch.no_grad()
@ignore_numpy_warnings
def model_image(
    fig,
    ax,
    model,
    sample_image=None,
    window=None,
    target=None,
    showcbar: bool = True,
    target_mask: bool = False,
    cmap_levels: Optional[int] = None,
    magunits: bool = True,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    **kwargs,
):
    """
    Generate a model image and display it on the provided axes.

    :param fig: The figure object in which the image will be displayed.
    :type fig: matplotlib.figure.Figure
    :param ax: The axes object on which the image will be plotted. For an image
        list, an indexable collection of axes with one entry per image.
    :type ax: matplotlib.axes.Axes
    :param model: The model object used to generate a model image if
        `sample_image` is not provided.
    :type model: Model
    :param sample_image: The image or list of images to be displayed. If `None`,
        a model image is generated using the provided `model`. Defaults to `None`.
    :type sample_image: Image or Image_List, optional
    :param window: The window through which the image is viewed. If `None`, the
        window of the provided `model` is used. Defaults to `None`.
    :type window: Window, optional
    :param target: The target or list of targets for the image or image list. If
        `None`, the target of the `model` is used. Defaults to `None`.
    :type target: Target, optional
    :param showcbar: Whether to show the color bar. Defaults to `True`.
    :type showcbar: bool, optional
    :param target_mask: Whether to apply the target's mask. If `True`, masked
        pixels are set to NaN (left blank). Defaults to `False`.
    :type target_mask: bool, optional
    :param cmap_levels: The number of discrete levels to convert the continuous
        color map to. If not `None`, the color map is converted to a
        ListedColormap with the specified number of levels. Defaults to `None`.
    :type cmap_levels: int, optional
    :param magunits: Whether to convert the image to surface brightness units.
        If `True` and the target has a zeropoint, it is used for the
        conversion. Otherwise the flux is shown with a log normalization.
        Defaults to `True`.
    :type magunits: bool, optional
    :param vmin: The minimum value for the color scale. Defaults to `None`.
    :type vmin: float, optional
    :param vmax: The maximum value for the color scale. Defaults to `None`.
    :type vmax: float, optional
    :param kwargs: Extra keyword arguments passed to ``ax.pcolormesh``. These
        override the defaults. ``cmap`` may be a Colormap or a registered
        colour map name.
    :returns: The ``(fig, ax)`` pair that was passed in.

    Note:
        If the `sample_image` is an `Image_List`, this function calls itself for
        each image in the list, with the corresponding target and window. All
        other options and `kwargs` are passed unchanged to each recursive call.
    """
    if sample_image is None:
        sample_image = model()
    # Use the model's target and window if not given.
    if target is None:
        target = model.target
    if window is None:
        window = model.window

    # Image lists: draw each image/target/window triple on its own axes.
    if isinstance(sample_image, ImageList):
        for i, (images, targets, windows) in enumerate(zip(sample_image, target, window)):
            model_image(
                fig,
                ax[i],
                model,
                sample_image=images,
                window=windows,
                target=targets,
                showcbar=showcbar,
                target_mask=target_mask,
                cmap_levels=cmap_levels,
                magunits=magunits,
                vmin=vmin,
                vmax=vmax,
                **kwargs,
            )
        return fig, ax

    # Cut out the requested window.
    sample_image = sample_image[window]
    X, Y = sample_image.coordinate_corner_meshgrid()
    X = backend.to_numpy(X)
    Y = backend.to_numpy(Y)
    # Copy because the NaN masking below writes in place, and to_numpy may
    # share memory with the image's tensor.
    data = np.copy(backend.to_numpy(sample_image._data))

    kwargs = {"cmap": cmap_grad, **kwargs}
    # Colour map names are resolved because methods are called on the cmap below.
    if isinstance(kwargs["cmap"], str):
        kwargs["cmap"] = matplotlib.colormaps[kwargs["cmap"]]

    # If requested, convert the continuous colour map into discrete levels.
    if cmap_levels is not None:
        kwargs["cmap"] = matplotlib.colors.ListedColormap(
            [kwargs["cmap"](c) for c in np.linspace(0.0, 1.0, cmap_levels)]
        )

    use_sb = target.zeropoint is not None and magunits
    if use_sb:
        # Convert to surface brightness using the target's zeropoint.
        data = flux_to_sb(data, target.pixel_area.item(), target.zeropoint.item())
        # Brighter regions have smaller magnitudes, so reverse the colour map.
        kwargs["cmap"] = kwargs["cmap"].reversed()
        kwargs["vmin"] = vmin
        kwargs["vmax"] = vmax
    else:
        kwargs = {
            "norm": matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax),
            **kwargs,
        }

    # Only blank masked pixels when explicitly requested.
    if target_mask:
        data[backend.to_numpy(target[window]._mask)] = np.nan

    im = ax.pcolormesh(X, Y, data, **kwargs)
    # A negative CD determinant means a flipped x axis on the sky.
    if np.linalg.det(target.CD.npvalue) < 0:
        ax.invert_xaxis()

    # Enforce equal spacing on x and y.
    ax.axis("equal")
    ax.set_xlabel("Tangent Plane X [arcsec]")
    ax.set_ylabel("Tangent Plane Y [arcsec]")

    if showcbar:
        if use_sb:
            clb = fig.colorbar(im, ax=ax, label="Surface Brightness [mag/arcsec$^2$]")
            clb.ax.invert_yaxis()
        else:
            fig.colorbar(im, ax=ax, label="log$_{10}$(flux)")
    return fig, ax
@torch.no_grad()
@ignore_numpy_warnings
def residual_image(
    fig,
    ax,
    model,
    target=None,
    sample_image=None,
    showcbar=True,
    window=None,
    clb_label=None,
    normalize_residuals=False,
    scaling: Literal["arctan", "clip", "none"] = "arctan",
    **kwargs,
):
    """
    Calculate and display the residuals (target - model image) on the provided axes.

    :param fig: The figure object in which the residuals will be displayed.
    :type fig: matplotlib.figure.Figure
    :param ax: The axes object on which the residuals will be plotted. For
        lists, an iterable of axes with one entry per image.
    :type ax: matplotlib.axes.Axes
    :param model: The model object used to generate a model image if
        `sample_image` is not provided. Its ``name`` appears in the colour bar
        label.
    :type model: Model
    :param target: The target or list of targets for the image or image list. If
        `None`, the target of the `model` is used. Defaults to `None`.
    :type target: Target or Image_List, optional
    :param sample_image: The image or list of images from which residuals will be
        calculated. If `None`, a model image is generated using the provided
        `model`. Defaults to `None`.
    :type sample_image: Image or Image_List, optional
    :param showcbar: Whether to show the color bar. Defaults to `True`.
    :type showcbar: bool, optional
    :param window: The window through which the image is viewed. If `None`, the
        window of the provided `model` is used. Defaults to `None`.
    :type window: Window or Window_List, optional
    :param clb_label: The label for the colorbar. If `None`, a default label is
        used based on the normalization and scaling. Defaults to `None`.
    :type clb_label: str, optional
    :param normalize_residuals: If `True`, residuals are divided by the square
        root of the target's variance. If a backend array is given, residuals
        are divided by its square root, i.e. it is treated as a variance.
        Defaults to `False`.
    :type normalize_residuals: bool or backend array, optional
    :param scaling: The scaling method for the residuals: "arctan", "clip", or
        "none".

        - "arctan" shows all residuals but compresses large values so that
          fainter residuals are more visible.
        - "clip" shows the residuals in linear space but clips them to +/-5.
        - "none" shows the residuals in linear space with no scaling.

        Defaults to "arctan".
    :type scaling: str, optional
    :param kwargs: Extra keyword arguments passed to ``ax.pcolormesh``. These
        override the defaults (``cmap``, ``vmin``, ``vmax``).
    :returns: The ``(fig, ax)`` pair that was passed in.
    :raises ValueError: If `scaling` is not one of the supported options.

    Note:
        If the `window` is a window list or the `target` is an image list, this
        function calls itself for each element, with the corresponding window,
        target and sample image. All other options and `kwargs` are passed
        unchanged to each recursive call.
    """
    # Reject a bad option before doing the (possibly expensive) model evaluation.
    if scaling not in ("clip", "arctan", "none"):
        raise ValueError(f"Unknown scaling type {scaling}. Use 'clip', 'arctan', or 'none'.")

    if window is None:
        window = model.window
    if target is None:
        target = model.target
    if sample_image is None:
        sample_image = model()

    if isinstance(window, WindowList) or isinstance(target, ImageList):
        for i_ax, win, tar, sam in zip(ax, window, target, sample_image):
            residual_image(
                fig,
                i_ax,
                model,
                target=tar,
                sample_image=sam,
                window=win,
                showcbar=showcbar,
                clb_label=clb_label,
                normalize_residuals=normalize_residuals,
                scaling=scaling,
                **kwargs,
            )
        return fig, ax

    sample_image = sample_image[window]
    target = target[window]
    X, Y = sample_image.coordinate_corner_meshgrid()
    X = backend.to_numpy(X)
    Y = backend.to_numpy(Y)

    residuals = (target - sample_image)._data
    if normalize_residuals is True:
        residuals = residuals / backend.sqrt(target._variance)
    elif isinstance(normalize_residuals, backend.array_type):
        # A user-supplied array is treated as a variance map.
        residuals = residuals / backend.sqrt(normalize_residuals)
        normalize_residuals = True
    residuals = backend.to_numpy(residuals)
    residuals[backend.to_numpy(target._mask)] = np.nan

    if scaling == "clip":
        if normalize_residuals is not True:
            config.logger.warning(
                "Using clipping scaling without normalizing residuals. This may lead to confusing results."
            )
        residuals = np.clip(residuals, -5, 5)
        vmax = 5
        default_label = (
            f"(Target - {model.name}) / $\\sigma$"
            if normalize_residuals
            else f"(Target - {model.name})"
        )
    elif scaling == "arctan":
        finite = residuals[np.isfinite(residuals)]
        spread = iqr(finite, rng=[10, 90]) * 2 if finite.size else np.nan
        # A zero or undefined spread (perfect fit, fully masked image) would
        # turn every residual into NaN, so fall back to unit scaling.
        if not (np.isfinite(spread) and spread > 0):
            spread = 1.0
        residuals = np.arctan(residuals / spread)
        vmax = np.pi / 2
        if normalize_residuals:
            default_label = f"tan$^{{-1}}$((Target - {model.name}) / $\\sigma$)"
        else:
            default_label = f"tan$^{{-1}}$(Target - {model.name})"
    else:  # scaling == "none"
        finite = residuals[np.isfinite(residuals)]
        # np.max raises on an empty array; use a unit range if nothing is finite.
        vmax = np.max(np.abs(finite)) if finite.size else 1.0
        default_label = (
            f"(Target - {model.name}) / $\\sigma$"
            if normalize_residuals
            else f"(Target - {model.name})"
        )

    imshow_kwargs = {
        "cmap": cmap_div,
        "vmin": -vmax,
        "vmax": vmax,
    }
    imshow_kwargs.update(kwargs)
    im = ax.pcolormesh(X, Y, residuals, **imshow_kwargs)
    # Same x-axis orientation rule as target_image/model_image: a negative CD
    # determinant means a flipped x axis on the sky.
    if np.linalg.det(target.CD.npvalue) < 0:
        ax.invert_xaxis()
    ax.axis("equal")
    ax.set_xlabel("Tangent Plane X [arcsec]")
    ax.set_ylabel("Tangent Plane Y [arcsec]")

    if showcbar:
        clb = fig.colorbar(im, ax=ax, label=default_label if clb_label is None else clb_label)
        clb.ax.set_yticks([])
        clb.ax.set_yticklabels([])
    return fig, ax
def _add_window_patch(ax, target, window, patch_kwargs):
    """Outline ``window`` (as cut out of ``target``) on ``ax`` as a Polygon patch."""
    corners = target[window].corners()
    xy = [(corners[k][0].item(), corners[k][1].item()) for k in range(4)]
    ax.add_patch(Polygon(xy=xy, **patch_kwargs))


@ignore_numpy_warnings
def model_window(fig, ax, model, target=None, rectangle_linewidth=2, **kwargs):
    """Plot the window(s) of a model on a target image.

    These windows bound the region that a model will be evaluated on and fit to.
    For a group model, the window of each sub-model is drawn.

    :param fig: The figure object in which the model window will be displayed.
    :type fig: matplotlib.figure.Figure
    :param ax: The axes object on which the model window will be plotted. If a
        numpy array of axes is given, one is used per image of ``target``.
    :type ax: matplotlib.axes.Axes
    :param model: The model object whose window will be displayed.
    :type model: Model
    :param target: The target or list of targets for the image or image list. If
        `None`, the target of the `model` is used. Defaults to `None`.
    :type target: Target or Image_List, optional
    :param rectangle_linewidth: The linewidth of the rectangle drawn around the
        model window. Defaults to 2.
    :type rectangle_linewidth: int, optional
    :param kwargs: Extra keyword arguments for the ``Polygon`` patch. These
        override the defaults (``fill``, ``linewidth``, ``edgecolor``).
    :returns: The ``(fig, ax)`` pair that was passed in.
    """
    if target is None:
        target = model.target

    # Array of axes: one axis per image in the target list.
    if isinstance(ax, np.ndarray):
        for i, axitem in enumerate(ax):
            model_window(
                fig,
                axitem,
                model,
                target=target.images[i],
                rectangle_linewidth=rectangle_linewidth,
                **kwargs,
            )
        return fig, ax

    # Merge rather than pass separately, so kwargs can override the defaults
    # without a "multiple values for keyword argument" TypeError.
    patch_kwargs = {
        "fill": False,
        "linewidth": rectangle_linewidth,
        "edgecolor": main_pallet["secondary1"],
        **kwargs,
    }

    if isinstance(model, GroupModel):
        for m in model.models:
            if isinstance(m.window, WindowList):
                # Pick the sub-window that belongs to this target image.
                use_window = m.window.windows[m.target.index(target)]
            else:
                use_window = m.window
            _add_window_patch(ax, target, use_window, patch_kwargs)
    else:
        _add_window_patch(ax, target, model.window, patch_kwargs)
    return fig, ax