from typing import Optional, Tuple, Union
import numpy as np
from astropy.wcs import WCS as AstropyWCS
from astropy.io import fits
from ..param import Module, Param, forward
from .. import config
from ..backend_obj import backend, ArrayLike
from ..utils.conversions.units import deg_to_arcsec, arcsec_to_deg
from .window import Window, WindowList, WindowBatch
from ..errors import InvalidImage, SpecificationConflict
# from .base import BaseImage
from . import func
__all__ = ["Image", "ImageList"]
class Image(Module):
    """Core class to represent images with pixel values, pixel scale,
    and a window defining the spatial coordinates on the sky. It supports
    arithmetic operations with other image objects while preserving logical
    image boundaries. It also provides methods for determining the coordinate
    locations of pixels.

    The pixel array is stored internally as ``_data`` in transposed order
    relative to the array handed in through ``data``; the ``data`` property
    transposes on the way in and on the way out.

    :param crval: The reference coordinate of the image in degrees [RA, DEC]. [model param]
    :param crtan: The tangent plane coordinate of the image in arcseconds [x, y]. [model param]
    :param CD: The coordinate transformation matrix in arcseconds/pixel. [model param]
    :param crpix: The reference pixel coordinates in the image, which is used to convert from pixel coordinates to tangent plane coordinates. This is not a model param and is fixed for a given image.
    :param data: The image data as a Array of pixel values. If not provided, an empty ``(0, 0)`` Array is stored.
    :param zeropoint: The zeropoint of the image, which is used to convert from pixel flux to magnitude.
    :param pixelscale: The side length of a pixel, used to create a simple diagonal CD matrix when ``CD`` is not given.
    :param wcs: An optional Astropy WCS object to initialize the image. Overrides ``crval``, ``crpix`` and ``CD``.
    :param filename: The filename to load the image from. If provided, the image will be loaded from the file.
    :param hduext: The HDU extension to load from the FITS file specified in `filename`.
    :param identity: An optional identity string for the image (mostly used internally). Defaults to ``id(self)``.
    :param name: Name forwarded to the ``Module`` constructor.
    :param _data: Pixel array already in the internal (transposed) layout. When given it is stored as-is and ``data`` is ignored.
    """

    expect_ctype = (("RA---TAN",), ("DEC--TAN",))
    base_scale = 1.0

    def __init__(
        self,
        *,
        data: Optional[ArrayLike] = None,
        CD: Optional[Union[float, ArrayLike]] = None,
        zeropoint: Optional[Union[float, ArrayLike]] = None,
        crpix: Union[ArrayLike, tuple] = (0.0, 0.0),
        crtan: Union[ArrayLike, tuple] = (0.0, 0.0),
        crval: Union[ArrayLike, tuple] = (0.0, 0.0),
        pixelscale: Optional[Union[ArrayLike, float]] = 1.0,
        wcs: Optional[AstropyWCS] = None,
        filename: Optional[str] = None,
        hduext: int = 0,
        identity: Optional[str] = None,
        name: Optional[str] = None,
        _data: Optional[ArrayLike] = None,
    ):
        super().__init__(name=name)
        if _data is None:
            self.data = data  # units: flux
        else:
            # Already in the internal layout; bypass the transposing setter.
            self._data = _data
        self.crtan = Param(
            "crtan",
            crtan,
            shape=(2,),
            units="arcsec",
            dtype=config.DTYPE,
            device=config.DEVICE,
        )
        self.zeropoint = zeropoint
        self._identity = id(self) if identity is None else identity

        if wcs is not None:
            # One warning per axis whose CTYPE is not the expected tangent-plane type.
            for axis in (0, 1):
                if wcs.wcs.ctype[axis] not in self.expect_ctype[axis]:
                    config.logger.warning(
                        "Astropy WCS not tangent plane coordinate system! May not be compatible with AstroPhot."
                    )
            crval = wcs.wcs.crval
            # Handle FITS 1-indexing. The axis order is kept as-is so that
            # CRPIX1 -> crpix[0] and CRPIX2 -> crpix[1], the same mapping that
            # ``load`` and ``fits_info`` use.
            crpix = np.array(wcs.wcs.crpix) - 1
            if CD is not None:
                config.logger.warning("WCS CD set with supplied WCS, ignoring user supplied CD!")
            CD = deg_to_arcsec * wcs.pixel_scale_matrix

        # set the data
        self.crval = Param(
            "crval", crval, shape=(2,), units="deg", dtype=config.DTYPE, device=config.DEVICE
        )
        self.crpix = crpix

        if isinstance(CD, (float, int)):
            # A scalar CD means square, axis-aligned pixels.
            CD = np.array([[CD, 0.0], [0.0, CD]], dtype=np.float64)
        elif CD is None:
            CD = np.array([[pixelscale, 0.0], [0.0, pixelscale]], dtype=np.float64)
        self.CD = Param(
            "CD",
            CD,
            shape=(2, 2),
            units="arcsec/pixel",
            dtype=config.DTYPE,
            device=config.DEVICE,
        )

        if filename is not None:
            hdulist = self.load(filename, hduext=hduext)
            # ``load`` copies everything it needs, so a handle that was opened
            # here from a path can be released. An HDUList supplied by the
            # caller is left open since the caller owns it.
            if isinstance(filename, str) and hdulist is not None:
                hdulist.close()

    @property
    def identity(self):
        """Identifier used to decide whether two images refer to the same underlying image."""
        return self._identity

    @property
    def data(self):
        """The image data, which is a Array of pixel values."""
        return backend.transpose(self._data, 1, 0)

    @data.setter
    def data(self, value: Optional[ArrayLike]):
        """Set the image data. If value is None, the data is initialized to an empty Array."""
        if value is None:
            self._data = backend.empty((0, 0), dtype=config.DTYPE, device=config.DEVICE)
        else:
            # Transpose since pytorch uses (j, i) indexing when (i, j) is more natural for coordinates
            self._data = backend.transpose(
                backend.as_array(value, dtype=config.DTYPE, device=config.DEVICE), 1, 0
            )

    @property
    def crpix(self) -> ArrayLike:
        """The reference pixel coordinates in the image, which is used to convert from pixel coordinates to tangent plane coordinates."""
        return self._crpix

    @crpix.setter
    def crpix(self, value: Union[ArrayLike, tuple]):
        # Always held as a float64 numpy array (it is not a model Param).
        self._crpix = np.array(value, dtype=np.float64)

    @property
    def zeropoint(self) -> ArrayLike:
        """The zeropoint of the image, which is used to convert from pixel flux to magnitude."""
        return self._zeropoint

    @zeropoint.setter
    def zeropoint(self, value):
        """Set the zeropoint of the image; ``None`` is stored unchanged."""
        if value is None:
            self._zeropoint = None
        else:
            self._zeropoint = backend.as_array(value, dtype=config.DTYPE, device=config.DEVICE)

    @property
    def window(self) -> Window:
        """A ``Window`` spanning the whole internal array, from ``(0, 0)`` to its shape."""
        return Window(window=((0, 0), self._data.shape[:2]), image=self)

    @property
    def center(self):
        """Tangent-plane coordinates of the middle of the pixel grid, stacked into one Array."""
        shape = backend.as_array(self._data.shape[:2], dtype=config.DTYPE, device=config.DEVICE)
        return backend.stack(self.pixel_to_plane(*((shape - 1) / 2)))

    # @property
    # def shape(self):
    # """The shape of the image data."""
    # return self.data.shape

    @property
    @forward
    def pixel_area(self, CD):
        """The area inside a pixel in arcsec^2"""
        return backend.abs(backend.linalg.det(CD))

    @property
    @forward
    def pixelscale(self):
        """The approximate side length of a pixel, which is just
        sqrt(pixel_area). For square pixels this is the actual pixel
        length, for rectangular pixels it is a kind of average.
        The pixelscale is not used for exact calculations
        and instead sets a size scale within an image.
        """
        return backend.sqrt(self.pixel_area)

    @forward
    def pixel_collecting_area(self, I_, J_, upsample, CD):
        """The area of the sky that each pixel collects light from, in arcsec^2.
        This is just the pixel area, but can be overridden for certain types of
        images (e.g. SIP images) where the pixel collecting area is not the same
        as the pixel area.

        ``I_`` and ``J_`` are unused here; the result is ``|det(CD)| / upsample**2``.
        """
        return backend.abs(backend.linalg.det(CD)) / upsample**2

    @property
    def flip_ra_axis(self):
        """True when the determinant of the CD matrix is negative."""
        return np.linalg.det(self.CD.npvalue) < 0

    @forward
    def pixel_to_plane(
        self,
        i: ArrayLike,
        j: ArrayLike,
        crtan: ArrayLike,
        CD: ArrayLike,
        _crpix: Optional[ArrayLike] = None,
    ) -> Tuple[ArrayLike, ArrayLike]:
        """Map pixel coordinates ``(i, j)`` to tangent-plane coordinates via
        ``func.pixel_to_plane_linear``.

        :param _crpix: Optional reference pixel to use instead of ``self.crpix``.
        """
        crpix = self.crpix if _crpix is None else _crpix
        return func.pixel_to_plane_linear(i, j, *crpix, CD, *crtan)

    @forward
    def plane_to_pixel(
        self,
        x: ArrayLike,
        y: ArrayLike,
        crtan: ArrayLike,
        CD: ArrayLike,
        _crpix: Optional[ArrayLike] = None,
    ) -> Tuple[ArrayLike, ArrayLike]:
        """Map tangent-plane coordinates ``(x, y)`` to pixel coordinates via
        ``func.plane_to_pixel_linear``.

        :param _crpix: Optional reference pixel to use instead of ``self.crpix``.
        """
        crpix = self.crpix if _crpix is None else _crpix
        return func.plane_to_pixel_linear(x, y, *crpix, CD, *crtan)

    @forward
    def plane_to_world(
        self, x: ArrayLike, y: ArrayLike, crval: ArrayLike
    ) -> Tuple[ArrayLike, ArrayLike]:
        """Map tangent-plane coordinates to world coordinates via
        ``func.plane_to_world_gnomonic`` about ``crval``."""
        return func.plane_to_world_gnomonic(x, y, *crval)

    @forward
    def world_to_plane(
        self, ra: ArrayLike, dec: ArrayLike, crval: ArrayLike
    ) -> Tuple[ArrayLike, ArrayLike]:
        """Map world coordinates to tangent-plane coordinates via
        ``func.world_to_plane_gnomonic`` about ``crval``."""
        return func.world_to_plane_gnomonic(ra, dec, *crval)

    @forward
    def world_to_pixel(self, ra: ArrayLike, dec: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
        """A wrapper which applies :meth:`world_to_plane` then
        :meth:`plane_to_pixel`, see those methods for further
        information.
        """
        return self.plane_to_pixel(*self.world_to_plane(ra, dec))

    @forward
    def pixel_to_world(self, i: ArrayLike, j: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
        """A wrapper which applies :meth:`pixel_to_plane` then
        :meth:`plane_to_world`, see those methods for further
        information.
        """
        return self.plane_to_world(*self.pixel_to_plane(i, j))

    def pixel_center_meshgrid(self, window=None, pad=0, upsample=1) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of pixel coordinates in the image, centered on the pixel grid.

        ``window`` defaults to the full image window.
        """
        if window is None:
            window = self.window
        return func.pixel_center_meshgrid(window.extent, pad, upsample, config.DTYPE, config.DEVICE)

    def pixel_corner_meshgrid(self, window=None, pad=0, upsample=1) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of pixel coordinates in the image, with corners at the pixel grid.

        ``window`` defaults to the full image window.
        """
        if window is None:
            window = self.window
        return func.pixel_corner_meshgrid(window.extent, pad, upsample, config.DTYPE, config.DEVICE)

    def pixel_simpsons_meshgrid(
        self, window=None, pad=0, upsample=1
    ) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of pixel coordinates in the image, with Simpson's rule sampling.

        ``window`` defaults to the full image window.
        """
        if window is None:
            window = self.window
        return func.pixel_simpsons_meshgrid(
            window.extent, pad, upsample, config.DTYPE, config.DEVICE
        )

    def pixel_quad_meshgrid(
        self, window=None, pad=0, upsample=1, order=3
    ) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of pixel coordinates in the image, with quadrature sampling.

        ``window`` defaults to the full image window.
        NOTE(review): :meth:`coordinate_quad_meshgrid` unpacks three values from
        this call, so the helper appears to return a third element (presumably
        quadrature weights) beyond what the annotation states -- confirm.
        """
        if window is None:
            window = self.window
        return func.pixel_quad_meshgrid(
            window.extent, pad, upsample, config.DTYPE, config.DEVICE, order=order
        )

    @forward
    def coordinate_center_meshgrid(self) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of coordinate locations in the image, centered on the pixel grid."""
        i, j = self.pixel_center_meshgrid()
        return self.pixel_to_plane(i, j)

    @forward
    def coordinate_corner_meshgrid(self) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of coordinate locations in the image, with corners at the pixel grid."""
        i, j = self.pixel_corner_meshgrid()
        return self.pixel_to_plane(i, j)

    @forward
    def coordinate_simpsons_meshgrid(self) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of coordinate locations in the image, with Simpson's rule sampling."""
        i, j = self.pixel_simpsons_meshgrid()
        return self.pixel_to_plane(i, j)

    @forward
    def coordinate_quad_meshgrid(self, order=3) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of coordinate locations in the image, with quadrature sampling."""
        i, j, _ = self.pixel_quad_meshgrid(order=order)
        return self.pixel_to_plane(i, j)

    def copy_kwargs(self, **kwargs) -> dict:
        """Build the constructor keyword arguments that reproduce this image.

        Entries in ``kwargs`` override the values taken from this image.
        """
        return {
            "_data": backend.copy(self._data),
            "CD": self.CD.value,
            "crpix": self.crpix,
            "crval": self.crval.value,
            "crtan": self.crtan.value,
            "zeropoint": self.zeropoint,
            "identity": self.identity,
            "name": self.name,
            **kwargs,
        }

    def copy(self, **kwargs):
        """Produce a copy of this image with all of the same properties. This
        can be used when one wishes to make temporary modifications to
        an image and then will want the original again.
        """
        return self.__class__(**self.copy_kwargs(**kwargs))

    def blank_copy(self, **kwargs):
        """Produces a blank copy of the image which has the same properties
        except that its data is now filled with zeros.
        """
        kwargs = {
            "_data": backend.zeros_like(self._data),
            **kwargs,
        }
        return self.copy(**kwargs)

    def crop(self, pixels: Union[int, Tuple[int, int], Tuple[int, int, int, int]], **kwargs):
        """Crop the image by the number of pixels given.

        Given data shape (N, M) the new shape will be:

        - int (or a length-1 sequence): crop the same number of pixels on all sides. new shape (N - 2*crop, M - 2*crop)
        - (int, int): crop each dimension by the number of pixels given. new shape (N - 2*crop[1], M - 2*crop[0])
        - (int, int, int, int): crop each side by the number of pixels given assuming (x low, x high, y low, y high). new shape (N - crop[2] - crop[3], M - crop[0] - crop[1])

        :param pixels: Crop specification, see above. NumPy integer scalars are accepted as well as Python ints.
        :returns: ``self`` when every crop amount is zero, otherwise a new image with cropped data and shifted ``crpix``.
        :raises ValueError: If ``pixels`` is a sequence whose length is not 1, 2 or 4.
        """
        if np.all(np.array(pixels) == 0):
            return self
        if isinstance(pixels, (int, np.integer)):
            # Treat a bare scalar like a length-1 sequence.
            pixels = (pixels,)

        count = len(pixels)
        if count == 1:  # same crop in all dimension
            low = (pixels[0], pixels[0])
            high = low
        elif count == 2:  # different crop in each dimension
            low = (pixels[0], pixels[1])
            high = low
        elif count == 4:  # different crop on all sides
            low = (pixels[0], pixels[2])
            high = (pixels[1], pixels[3])
        else:
            raise ValueError(
                f"Invalid crop shape {pixels}, must be (int,), (int, int), or (int, int, int, int)!"
            )

        data = self._data[
            low[0] : self._data.shape[0] - high[0],
            low[1] : self._data.shape[1] - high[1],
        ]
        # Only the pixels removed from the low side move the reference pixel.
        crpix = self.crpix - np.array(low)
        return self.copy(_data=data, crpix=crpix, **kwargs)

    def reduce(self, scale: int, **kwargs):
        """This operation will downsample an image by the factor given. If
        scale = 2 then 2x2 blocks of pixels will be summed together to
        form individual larger pixels. A new image object will be
        returned with the appropriate pixelscale and data Array. Note
        that the window does not change in this operation since the
        pixels are condensed, but the pixel size is increased
        correspondingly.

        Trailing rows/columns that do not fill a whole ``scale`` block are dropped.

        :param scale: The scale factor by which to reduce the image.
        :type scale: int
        :returns: ``self`` when ``scale == 1``, otherwise a new reduced image.
        :raises SpecificationConflict: If ``scale`` is not an integer.
        """
        if not isinstance(scale, int) and not (
            isinstance(scale, ArrayLike) and scale.dtype is backend.int32
        ):
            raise SpecificationConflict(f"Reduce scale must be an integer! not {type(scale)}")
        if scale == 1:
            return self

        MS = self._data.shape[0] // scale
        NS = self._data.shape[1] // scale
        data = self._data[: MS * scale, : NS * scale].reshape(MS, scale, NS, scale).sum(axis=(1, 3))
        CD = self.CD.value * scale
        # Pixel centres sit at integer coordinates, so shift to pixel edges
        # before rescaling and back to centres afterwards.
        crpix = (self.crpix + 0.5) / scale - 0.5
        return self.copy(
            _data=data,
            CD=CD,
            crpix=crpix,
            **kwargs,
        )

    def to(self, dtype=None, device=None):
        """Move the image (params, data and zeropoint) to ``dtype``/``device``.

        Either argument left as ``None`` falls back to ``config.DTYPE`` /
        ``config.DEVICE``. Returns ``self``.
        """
        if dtype is None:
            dtype = config.DTYPE
        if device is None:
            device = config.DEVICE
        super().to(dtype=dtype, device=device)
        self._data = backend.to(self._data, dtype=dtype, device=device)
        if self.zeropoint is not None:
            self.zeropoint = backend.to(self.zeropoint, dtype=dtype, device=device)
        return self

    def flatten(self, attribute: str = "data") -> ArrayLike:
        """Return the named attribute (``data`` by default) flattened by ``backend.flatten``."""
        return backend.flatten(getattr(self, attribute))

    def fits_info(self) -> dict:
        """Header keywords describing this image's coordinate system.

        ``CRPIX`` values are converted to FITS 1-indexing and ``CD`` values
        from arcsec to degrees. ``MAGZP`` is ``-999`` when no zeropoint is set.
        """
        return {
            "CTYPE1": "RA---TAN",
            "CTYPE2": "DEC--TAN",
            "CRVAL1": self.crval.value[0].item(),
            "CRVAL2": self.crval.value[1].item(),
            "CRPIX1": self.crpix[0] + 1,
            "CRPIX2": self.crpix[1] + 1,
            "CRTAN1": self.crtan.value[0].item(),
            "CRTAN2": self.crtan.value[1].item(),
            "CD1_1": self.CD.value[0][0].item() * arcsec_to_deg,
            "CD1_2": self.CD.value[0][1].item() * arcsec_to_deg,
            "CD2_1": self.CD.value[1][0].item() * arcsec_to_deg,
            "CD2_2": self.CD.value[1][1].item() * arcsec_to_deg,
            "MAGZP": self.zeropoint.item() if self.zeropoint is not None else -999,
            "IDNTY": self.identity,
        }

    def fits_images(self):
        """Return a one-element list holding a ``PrimaryHDU`` with the image
        data (transposed back from the internal layout) and the header from
        :meth:`fits_info`."""
        return [
            fits.PrimaryHDU(
                backend.to_numpy(backend.transpose(self._data, 1, 0)),
                header=fits.Header(self.fits_info()),
            )
        ]

    def get_astropywcs(self, **kwargs):
        """Build an Astropy WCS from :meth:`fits_info` plus the image size.

        ``kwargs`` are merged last and so may override any header keyword.
        """
        header = {
            "NAXIS": 2,
            # The first internal axis is the one CRPIX1 refers to, so it
            # supplies NAXIS1. Array shapes are plain ints already.
            "NAXIS1": int(self._data.shape[0]),
            "NAXIS2": int(self._data.shape[1]),
            **self.fits_info(),
            **kwargs,
        }
        return AstropyWCS(header)

    def save(self, filename: str):
        """Write the HDUs from :meth:`fits_images` to ``filename``, overwriting any existing file."""
        hdulist = fits.HDUList(self.fits_images())
        hdulist.writeto(filename, overwrite=True)

    def load(self, filename: Union[str, fits.HDUList], hduext: int = 0):
        """Load an image from a FITS file. This reads HDU ``hduext`` and sets
        the data, CD, crpix and crval attributes from it. ``crtan`` is set
        only when both ``CRTAN1`` and ``CRTAN2`` are present, and the
        zeropoint only when ``MAGZP`` is present and greater than -998.

        :param filename: Path to a FITS file, or an already opened ``HDUList``.
        :param hduext: Index of the HDU to read.
        :returns: The ``HDUList`` that was read. When a path was given the
            list was opened here and is returned still open.
        """
        if isinstance(filename, str):
            hdulist = fits.open(filename)
        else:
            hdulist = filename
        hdu = hdulist[hduext]
        header = hdu.header

        self.data = np.array(hdu.data, dtype=np.float64)
        self.CD = (
            np.array(
                (
                    (header["CD1_1"], header["CD1_2"]),
                    (header["CD2_1"], header["CD2_2"]),
                ),
                dtype=np.float64,
            )
            * deg_to_arcsec
        )
        # FITS reference pixels are 1-indexed.
        self.crpix = (header["CRPIX1"] - 1, header["CRPIX2"] - 1)
        self.crval = (header["CRVAL1"], header["CRVAL2"])
        if "CRTAN1" in header and "CRTAN2" in header:
            self.crtan = (header["CRTAN1"], header["CRTAN2"])
        # -999 is the sentinel written by ``fits_info`` for "no zeropoint".
        if "MAGZP" in header and header["MAGZP"] > -998:
            self.zeropoint = header["MAGZP"]
        self._identity = header.get("IDNTY", str(id(self)))
        return hdulist

    def corners(
        self,
    ) -> Tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike]:
        """Tangent-plane coordinates of the four outer corners of the image.

        Corners lie half a pixel outside the first/last pixel centres.
        Returned in the order (lowleft, lowright, upright, upleft).
        """
        pixel_lowleft = backend.make_array((-0.5, -0.5), dtype=config.DTYPE, device=config.DEVICE)
        pixel_lowright = backend.make_array(
            (self._data.shape[0] - 0.5, -0.5), dtype=config.DTYPE, device=config.DEVICE
        )
        pixel_upleft = backend.make_array(
            (-0.5, self._data.shape[1] - 0.5), dtype=config.DTYPE, device=config.DEVICE
        )
        pixel_upright = backend.make_array(
            (self._data.shape[0] - 0.5, self._data.shape[1] - 0.5),
            dtype=config.DTYPE,
            device=config.DEVICE,
        )
        lowleft = self.pixel_to_plane(*pixel_lowleft)
        lowright = self.pixel_to_plane(*pixel_lowright)
        upleft = self.pixel_to_plane(*pixel_upleft)
        upright = self.pixel_to_plane(*pixel_upright)
        return (lowleft, lowright, upright, upleft)

    def get_indices(self, other: Window):
        """Slices into this image's internal array covering its overlap with ``other``.

        If the window belongs to a different image, its bounds are shifted by
        the rounded difference in ``crpix``; a warning is logged when the two
        identities differ. All bounds are clamped to the array extent.
        """
        if other.image is self:
            return slice(max(0, other.i_low), min(self._data.shape[0], other.i_high)), slice(
                max(0, other.j_low), min(self._data.shape[1], other.j_high)
            )
        if other.image.identity != self.identity:
            config.logger.warning(
                f"Attempting to match windows with different images! Window image: {other.image.name}, {other.image.identity}, self image: {self.name}, {self.identity}. This may fail unless you are sure the two images are on the same pixel grid."
            )
        shift = np.round(self.crpix - other.crpix).astype(int)
        return slice(
            min(max(0, other.i_low + shift[0]), self._data.shape[0]),
            max(0, min(other.i_high + shift[0], self._data.shape[0])),
        ), slice(
            min(max(0, other.j_low + shift[1]), self._data.shape[1]),
            max(0, min(other.j_high + shift[1], self._data.shape[1])),
        )

    def get_other_indices(self, other: Window):
        """Slices, relative to the window ``other``, of the part that lies inside this image.

        :raises ValueError: If ``other`` is not a window on this image.
        """
        if other.image == self:  # fixme check identity, or check "is"?
            shape = other.shape
            return slice(
                max(0, -other.i_low), min(self._data.shape[0] - other.i_low, shape[0])
            ), slice(max(0, -other.j_low), min(self._data.shape[1] - other.j_low, shape[1]))
        raise ValueError("get_other_indices requires a window defined on this image")

    def get_window(self, other: Union[Window, "Image"], indices=None, **kwargs):
        """Get a new image object which is a window of this image
        corresponding to the other image's window. This will return a
        new image object with the same properties as this one, but with
        the data cropped to the other image's window.

        :param indices: Precomputed pair of slices; computed with :meth:`get_indices` when omitted.
        """
        if indices is None:
            indices = self.get_indices(other if isinstance(other, Window) else other.window)
        new_img = self.copy(
            _data=self._data[indices],
            # Keep the reference pixel pointing at the same sky position.
            crpix=self.crpix - np.array((indices[0].start, indices[1].start)),
            **kwargs,
        )
        return new_img

    def __sub__(self, other):
        # Image operand: subtract on the overlap only. Otherwise subtract the
        # operand from a copy of the whole image.
        if isinstance(other, Image):
            new_img = self[other]
            new_img._data = new_img._data - other[self]._data
            return new_img
        new_img = self.copy()
        new_img._data = new_img._data - other
        return new_img

    def __add__(self, other):
        # Image operand: add on the overlap only. Otherwise add the operand to
        # a copy of the whole image.
        if isinstance(other, Image):
            new_img = self[other]
            new_img._data = new_img._data + other[self]._data
            return new_img
        new_img = self.copy()
        new_img._data = new_img._data + other
        return new_img

    def __iadd__(self, other):
        # Image operand: accumulate only into the overlapping region.
        if isinstance(other, Image):
            self._data = backend.add_at_indices(
                self._data,
                self.get_indices(other.window),
                other._data[other.get_indices(self.window)],
            )
        else:
            self._data = self._data + other
        return self

    def __isub__(self, other):
        # Image operand: add the negated overlap into the overlapping region.
        if isinstance(other, Image):
            self._data = backend.add_at_indices(
                self._data,
                self.get_indices(other.window),
                -other._data[other.get_indices(self.window)],
            )
        else:
            self._data = self._data - other
        return self

    def __getitem__(self, *args):
        # Indexing with an Image or Window returns the matching sub-image;
        # anything else is handled by the parent class.
        if len(args) == 1 and isinstance(args[0], (Image, Window)):
            return self.get_window(args[0])
        return super().__getitem__(*args)
class ImageList(Module):
    """A class to represent a list of images.

    This is useful for operations that involve multiple images, mostly for joint
    modelling. The ImageList class provides methods for matching images based on
    their identity, and for applying operations to all images in the list while
    preserving their individual properties. For certain applications (the
    ``flatten`` method) you can use ImageList and Image objects
    interchangeably/agnostically.

    :param images: Iterable of ``Image`` objects; stored as a list.
    :raises InvalidImage: If any element is not an ``Image``.
    """

    def __init__(self, images: list[Image], **kwargs):
        super().__init__(**kwargs)
        self.images = list(images)
        if not all(isinstance(image, Image) for image in self.images):
            raise InvalidImage(
                f"ImageList can only hold Image objects, not {tuple(type(image) for image in self.images)}"
            )

    @property
    def data(self):
        """Tuple of each image's ``data``."""
        return tuple(image.data for image in self.images)

    @property
    def _data(self):
        """Tuple of each image's internal ``_data``."""
        return tuple(image._data for image in self.images)

    @_data.setter
    def _data(self, value):
        # One entry per image, assigned in order.
        if len(value) != len(self.images):
            raise ValueError(
                f"Expected an object of length {len(self.images)} for _data, but got {type(value)} of length {len(value)}"
            )
        for image, data in zip(self.images, value):
            image._data = data

    @property
    def window(self):
        """A ``WindowList`` of each image's window."""
        return WindowList(tuple(image.window for image in self.images))

    def copy(self):
        """Return a new list of the same class holding a copy of every image."""
        return self.__class__(
            tuple(image.copy() for image in self.images),
        )

    def blank_copy(self):
        """Return a new list of the same class holding a blank copy of every image."""
        return self.__class__(
            tuple(image.blank_copy() for image in self.images),
        )

    def get_window(self, other: "ImageList"):
        """Pair the images of both lists positionally and window each of
        ours by its partner (pairing stops at the shorter list)."""
        return self.__class__(
            tuple(image[win] for image, win in zip(self.images, other.images)),
        )

    def index(self, other: Image):
        """Position of the first image whose identity equals ``other.identity``.

        :raises IndexError: If no image in the list has that identity.
        """
        for i, image in enumerate(self.images):
            if other.identity == image.identity:
                return i
        raise IndexError(
            f"Could not find identity match between image list {self.name} and input image {other.name}"
        )

    def match_indices(self, other: "ImageList"):
        """Match the indices of the images in this list with those in another ImageList.

        Images of ``other`` with no identity match here are skipped.
        """
        indices = []
        for other_image in other.images:
            try:
                i = self.index(other_image)
            except IndexError:
                continue
            indices.append(i)
        return indices

    def to(self, dtype=None, device=None):
        """Forward to ``Module.to``, using ``config.DTYPE`` / ``config.DEVICE``
        for any argument left as ``None``. Returns ``self``."""
        # Defaults apply only when the caller gave no value, as in ``Image.to``.
        if dtype is None:
            dtype = config.DTYPE
        if device is None:
            device = config.DEVICE
        super().to(dtype=dtype, device=device)
        return self

    def flatten(self, attribute: str = "data") -> ArrayLike:
        """Concatenate the flattened ``attribute`` of every image into one Array."""
        return backend.concatenate(tuple(image.flatten(attribute) for image in self.images))

    def __sub__(self, other):
        # Unlike ``__add__``, an image of ``other`` with no identity match is
        # not skipped: the IndexError from ``index`` propagates.
        if isinstance(other, ImageList):
            new_list = []
            for other_image in other.images:
                i = self.index(other_image)
                self_image = self.images[i]
                new_list.append(self_image - other_image)
            return self.__class__(new_list)
        else:
            raise ValueError("Subtraction of ImageList only works with another ImageList object!")

    def __add__(self, other):
        # Images of ``other`` with no identity match in this list are skipped.
        if isinstance(other, ImageList):
            new_list = []
            for other_image in other.images:
                try:
                    i = self.index(other_image)
                except IndexError:
                    continue
                self_image = self.images[i]
                new_list.append(self_image + other_image)
            return self.__class__(new_list)
        else:
            raise ValueError("Addition of ImageList only works with another ImageList object!")

    def __isub__(self, other):
        # ImageList operand: unmatched images are skipped.
        # Single Image operand: a failed match raises IndexError.
        if isinstance(other, ImageList):
            for other_image in other.images:
                try:
                    i = self.index(other_image)
                except IndexError:
                    continue
                self.images[i] -= other_image
        elif isinstance(other, Image):
            i = self.index(other)
            self.images[i] -= other
        else:
            raise ValueError("Subtraction of ImageList only works with another ImageList object!")
        return self

    def __iadd__(self, other):
        # ImageList operand: unmatched images are skipped.
        # Single Image operand: a failed match raises IndexError.
        if isinstance(other, ImageList):
            for other_image in other.images:
                try:
                    i = self.index(other_image)
                except IndexError:
                    continue
                self.images[i] += other_image
        elif isinstance(other, Image):
            i = self.index(other)
            self.images[i] += other
        else:
            raise ValueError("Addition of ImageList only works with another ImageList object!")
        return self

    def __getitem__(self, *args):
        # Supported single keys:
        #   ImageList / WindowList -> new list of matching sub-images
        #   Image / Window         -> the matching sub-image
        #   int                    -> the image at that position
        # Anything else is handled by the parent class.
        if len(args) == 1:
            key = args[0]
            if isinstance(key, ImageList):
                new_list = []
                for other_image in key.images:
                    i = self.index(other_image)
                    new_list.append(self.images[i].get_window(other_image))
                return self.__class__(new_list)
            elif isinstance(key, WindowList):
                new_list = []
                for other_window in key.windows:
                    i = self.index(other_window.image)
                    new_list.append(self.images[i].get_window(other_window))
                return self.__class__(new_list)
            elif isinstance(key, Image):
                i = self.index(key)
                return self.images[i].get_window(key)
            elif isinstance(key, Window):
                i = self.index(key.image)
                return self.images[i].get_window(key)
            elif isinstance(key, int):
                return self.images[key]
        # Return the parent's result; previously it was computed and dropped.
        return super().__getitem__(*args)

    def __iter__(self):
        return (img for img in self.images)
class ImageBatchMixin:
    """Mixin that turns an image list into a batch of equally shaped images.

    A batch is stricter than a plain ImageList about what it may hold, and in
    return exposes its members' arrays stacked along a new leading axis so that
    vectorized operations over the batch are possible.

    Some notes to keep in mind:

    - All the images must be the regular image type (i.e. not SIP or CMOS images yet).
    - All the images must have the same shape, otherwise the batch operations will not work.
    - The ImageBatch does not itself accelerate any operations, it facilitates the BatchSceneModel.
    - Otherwise the ImageBatch behaves like a regular ImageList.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        members = self.images

        # Every member must be an Image ...
        if any(not isinstance(member, Image) for member in members):
            raise InvalidImage(
                f"ImageBatch can only hold Image objects, not {tuple(type(image) for image in self.images)}"
            )
        # ... an instance of the first member's class ...
        if any(not isinstance(member, members[0].__class__) for member in members):
            raise InvalidImage(
                f"ImageBatch images must all be of the same type, not {tuple(type(image) for image in self.images)}"
            )
        # ... and have the same data shape as the first member.
        if any(member.data.shape != members[0].data.shape for member in members):
            raise InvalidImage(
                f"All images in an ImageBatch must have the same shape, but got shapes {tuple(image.data.shape for image in self.images)}"
            )

    @property
    def data(self):
        """Each member's ``data`` stacked along a new leading axis."""
        per_image = tuple(member.data for member in self.images)
        return backend.stack(per_image, dim=0)

    @ImageList._data.getter
    def _data(self):
        """Each member's internal ``_data`` stacked along a new leading axis."""
        per_image = tuple(member._data for member in self.images)
        return backend.stack(per_image, dim=0)

    @property
    def window(self):
        """A ``WindowBatch`` built from every member's window."""
        windows = tuple(member.window for member in self.images)
        return WindowBatch(windows)

    @property
    def crval(self):
        """Each member's ``crval`` value stacked along a new leading axis."""
        values = tuple(member.crval.value for member in self.images)
        return backend.stack(values, dim=0)

    @property
    def crtan(self):
        """Each member's ``crtan`` value stacked along a new leading axis."""
        values = tuple(member.crtan.value for member in self.images)
        return backend.stack(values, dim=0)

    @property
    def CD(self):
        """Each member's ``CD`` value stacked along a new leading axis."""
        values = tuple(member.CD.value for member in self.images)
        return backend.stack(values, dim=0)

    @property
    def crpix(self):
        """Each member's ``crpix`` stacked with numpy, then converted to a backend Array."""
        stacked = np.stack(tuple(member.crpix for member in self.images), axis=0)
        return backend.as_array(stacked, dtype=config.DTYPE, device=config.DEVICE)