Source code for astrophot.image.image_object

from typing import Optional, Tuple, Union

import numpy as np
from astropy.wcs import WCS as AstropyWCS
from astropy.io import fits

from ..param import Module, Param, forward
from .. import config
from ..backend_obj import backend, ArrayLike
from ..utils.conversions.units import deg_to_arcsec, arcsec_to_deg
from .window import Window, WindowList, WindowBatch
from ..errors import InvalidImage, SpecificationConflict

# from .base import BaseImage
from . import func

__all__ = ["Image", "ImageList"]


class Image(Module):
    """Core class to represent images with pixel values, pixel scale, and a
    window defining the spatial coordinates on the sky. It supports
    arithmetic operations with other image objects while preserving logical
    image boundaries. It also provides methods for determining the
    coordinate locations of pixels.

    Pixel values are stored internally transposed: ``_data[i, j]`` where
    ``i`` runs along the FITS x axis (NAXIS1) and ``j`` along the y axis
    (NAXIS2). The public ``data`` property returns them in the usual
    numpy/FITS ``(row, column)`` order. All per-axis quantities such as
    ``crpix`` use the same ``(x, y)`` = ``(i, j)`` order.

    **Args:**
    - `data`: The image data as a tensor of pixel values. If not provided,
      an empty tensor is created.
    - `CD`: The coordinate transformation matrix in arcsec/pixel. A scalar
      builds a diagonal matrix; if None, `pixelscale` is used instead.
    - `zeropoint`: The zeropoint of the image, which is used to convert from
      pixel flux to magnitude.
    - `crpix`: The 0-indexed reference pixel ``(x, y)``, which is used to
      convert from pixel coordinates to tangent plane coordinates.
    - `crtan`: Tangent-plane coordinate of the reference pixel in arcsec.
    - `crval`: World coordinate ``(RA, DEC)`` of the tangent point in degrees.
    - `pixelscale`: The side length of a pixel, used to create a simple
      diagonal CD matrix when `CD` is not given.
    - `wcs`: An optional Astropy WCS object to initialize the image. Its
      CRVAL, CRPIX and pixel scale matrix override `crval`, `crpix` and `CD`.
    - `filename`: The filename to load the image from. If provided, the image
      will be loaded from the file.
    - `hduext`: The HDU extension to load from the FITS file specified in
      `filename`.
    - `identity`: An optional identity for the image; defaults to ``id(self)``.
    - `name`: Optional module name.
    - `_data`: Private. Pixel data already in internal (transposed) order,
      used by the copy/crop helpers to avoid a double transpose.

    These parameters are added to the optimization model:

    **Parameters:**
    - `crval`: The reference coordinate of the image in degrees [RA, DEC].
    - `crtan`: The tangent plane coordinate of the image in arcseconds [x, y].
    - `CD`: The coordinate transformation matrix in arcseconds/pixel.
    """

    # Accepted CTYPE values for (axis 1, axis 2).
    expect_ctype = (("RA---TAN",), ("DEC--TAN",))
    base_scale = 1.0

    def __init__(
        self,
        *,
        data: Optional[ArrayLike] = None,
        CD: Optional[Union[float, ArrayLike]] = None,
        zeropoint: Optional[Union[float, ArrayLike]] = None,
        crpix: Union[ArrayLike, tuple] = (0.0, 0.0),
        crtan: Union[ArrayLike, tuple] = (0.0, 0.0),
        crval: Union[ArrayLike, tuple] = (0.0, 0.0),
        pixelscale: Optional[Union[ArrayLike, float]] = 1.0,
        wcs: Optional[AstropyWCS] = None,
        filename: Optional[str] = None,
        hduext: int = 0,
        identity: Optional[str] = None,
        name: Optional[str] = None,
        _data: Optional[ArrayLike] = None,
    ):
        super().__init__(name=name)
        if _data is None:
            self.data = data  # units: flux
        else:
            # Already in internal (transposed) order, e.g. coming from copy().
            self._data = _data
        self.crtan = Param(
            "crtan",
            crtan,
            shape=(2,),
            units="arcsec",
            dtype=config.DTYPE,
            device=config.DEVICE,
        )
        self.zeropoint = zeropoint
        self._identity = id(self) if identity is None else identity

        if wcs is not None:
            if wcs.wcs.ctype[0] not in self.expect_ctype[0]:
                config.logger.warning(
                    "Astropy WCS not tangent plane coordinate system! May not be compatible with AstroPhot."
                )
            if wcs.wcs.ctype[1] not in self.expect_ctype[1]:
                config.logger.warning(
                    "Astropy WCS not tangent plane coordinate system! May not be compatible with AstroPhot."
                )
            crval = wcs.wcs.crval
            # FITS CRPIX is 1-indexed and ordered (CRPIX1, CRPIX2) = (x, y),
            # which is the same axis order as self.crpix (see fits_info/load),
            # so only the index origin needs shifting -- no axis reversal.
            crpix = np.array(wcs.wcs.crpix, dtype=np.float64) - 1
            if CD is not None:
                config.logger.warning("WCS CD set with supplied WCS, ignoring user supplied CD!")
            CD = deg_to_arcsec * wcs.pixel_scale_matrix

        self.crval = Param(
            "crval", crval, shape=(2,), units="deg", dtype=config.DTYPE, device=config.DEVICE
        )
        self.crpix = crpix
        if isinstance(CD, (float, int)):
            CD = np.array([[CD, 0.0], [0.0, CD]], dtype=np.float64)
        elif CD is None:
            CD = np.array([[pixelscale, 0.0], [0.0, pixelscale]], dtype=np.float64)
        self.CD = Param(
            "CD",
            CD,
            shape=(2, 2),
            units="arcsec/pixel",
            dtype=config.DTYPE,
            device=config.DEVICE,
        )

        if filename is not None:
            self.load(filename, hduext=hduext)

    @property
    def identity(self):
        """Identifier used to match images/windows across containers."""
        return self._identity

    @property
    def data(self):
        """The image data, which is a tensor of pixel values (row, column order)."""
        return backend.transpose(self._data, 1, 0)

    @data.setter
    def data(self, value: Optional[ArrayLike]):
        """Set the image data. If value is None, the data is initialized to an empty tensor."""
        if value is None:
            self._data = backend.empty((0, 0), dtype=config.DTYPE, device=config.DEVICE)
        else:
            # Transpose since pytorch uses (j, i) indexing when (i, j) is more natural for coordinates
            self._data = backend.transpose(
                backend.as_array(value, dtype=config.DTYPE, device=config.DEVICE), 1, 0
            )

    @property
    def crpix(self) -> ArrayLike:
        """The 0-indexed reference pixel (x, y), which is used to convert from
        pixel coordinates to tangent plane coordinates."""
        return self._crpix

    @crpix.setter
    def crpix(self, value: Union[ArrayLike, tuple]):
        self._crpix = np.array(value, dtype=np.float64)

    @property
    def zeropoint(self) -> ArrayLike:
        """The zeropoint of the image, which is used to convert from pixel flux to magnitude."""
        return self._zeropoint

    @zeropoint.setter
    def zeropoint(self, value):
        """Set the zeropoint of the image (cast to config.DTYPE / config.DEVICE)."""
        if value is None:
            self._zeropoint = None
        else:
            self._zeropoint = backend.as_array(value, dtype=config.DTYPE, device=config.DEVICE)

    @property
    def window(self) -> Window:
        """Window covering every pixel of this image."""
        return Window(window=((0, 0), self._data.shape[:2]), image=self)

    @property
    def center(self):
        """Tangent-plane coordinate of the geometric center of the pixel grid."""
        shape = backend.as_array(self._data.shape[:2], dtype=config.DTYPE, device=config.DEVICE)
        return backend.stack(self.pixel_to_plane(*((shape - 1) / 2)))

    @property
    @forward
    def pixel_area(self, CD):
        """The area inside a pixel in arcsec^2"""
        return backend.abs(backend.linalg.det(CD))

    @property
    @forward
    def pixelscale(self):
        """The approximate side length of a pixel, which is just sqrt(pixel_area).

        For square pixels this is the actual pixel length, for rectangular
        pixels it is a kind of average. The pixelscale is not used for exact
        calculations and instead sets a size scale within an image.
        """
        return backend.sqrt(self.pixel_area)

    @forward
    def pixel_collecting_area(self, I_, J_, upsample, CD):
        """The area of the sky that each pixel collects light from, in arcsec^2.

        This is just the pixel area divided by ``upsample**2``, but can be
        overridden for certain types of images (e.g. SIP images) where the
        pixel collecting area is not the same as the pixel area.
        """
        return backend.abs(backend.linalg.det(CD)) / upsample**2

    @property
    def flip_ra_axis(self):
        """True when the CD matrix has a negative determinant (mirrored axes)."""
        return np.linalg.det(self.CD.npvalue) < 0

    @forward
    def pixel_to_plane(
        self,
        i: ArrayLike,
        j: ArrayLike,
        crtan: ArrayLike,
        CD: ArrayLike,
        _crpix: Optional[ArrayLike] = None,
    ) -> Tuple[ArrayLike, ArrayLike]:
        """Map pixel indices (i, j) to tangent-plane coordinates (arcsec) with
        the linear CD transform. ``_crpix`` overrides ``self.crpix`` if given."""
        crpix = self.crpix if _crpix is None else _crpix
        return func.pixel_to_plane_linear(i, j, *crpix, CD, *crtan)

    @forward
    def plane_to_pixel(
        self,
        x: ArrayLike,
        y: ArrayLike,
        crtan: ArrayLike,
        CD: ArrayLike,
        _crpix: Optional[ArrayLike] = None,
    ) -> Tuple[ArrayLike, ArrayLike]:
        """Inverse of :meth:`pixel_to_plane`: tangent-plane (arcsec) to pixel indices."""
        crpix = self.crpix if _crpix is None else _crpix
        return func.plane_to_pixel_linear(x, y, *crpix, CD, *crtan)

    @forward
    def plane_to_world(
        self, x: ArrayLike, y: ArrayLike, crval: ArrayLike
    ) -> Tuple[ArrayLike, ArrayLike]:
        """Gnomonic deprojection of tangent-plane coordinates around ``crval``."""
        return func.plane_to_world_gnomonic(x, y, *crval)

    @forward
    def world_to_plane(
        self, ra: ArrayLike, dec: ArrayLike, crval: ArrayLike
    ) -> Tuple[ArrayLike, ArrayLike]:
        """Gnomonic projection of (ra, dec) onto the tangent plane at ``crval``."""
        return func.world_to_plane_gnomonic(ra, dec, *crval)

    @forward
    def world_to_pixel(self, ra: ArrayLike, dec: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
        """A wrapper which applies :meth:`world_to_plane` then
        :meth:`plane_to_pixel`, see those methods for further information.
        """
        return self.plane_to_pixel(*self.world_to_plane(ra, dec))

    @forward
    def pixel_to_world(self, i: ArrayLike, j: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
        """A wrapper which applies :meth:`pixel_to_plane` then
        :meth:`plane_to_world`, see those methods for further information.
        """
        return self.plane_to_world(*self.pixel_to_plane(i, j))

    def pixel_center_meshgrid(self, window=None, pad=0, upsample=1) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of pixel coordinates in the image, centered on the pixel grid."""
        if window is None:
            window = self.window
        return func.pixel_center_meshgrid(window.extent, pad, upsample, config.DTYPE, config.DEVICE)

    def pixel_corner_meshgrid(self, window=None, pad=0, upsample=1) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of pixel coordinates in the image, with corners at the pixel grid."""
        if window is None:
            window = self.window
        return func.pixel_corner_meshgrid(window.extent, pad, upsample, config.DTYPE, config.DEVICE)

    def pixel_simpsons_meshgrid(
        self, window=None, pad=0, upsample=1
    ) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of pixel coordinates in the image, with Simpson's rule sampling."""
        if window is None:
            window = self.window
        return func.pixel_simpsons_meshgrid(
            window.extent, pad, upsample, config.DTYPE, config.DEVICE
        )

    def pixel_quad_meshgrid(
        self, window=None, pad=0, upsample=1, order=3
    ) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
        """Get a meshgrid of pixel coordinates in the image, with quadrature sampling.

        Returns three arrays; the first two are the (i, j) sample coordinates.
        NOTE(review): the third is presumably the quadrature weights -- confirm
        against ``func.pixel_quad_meshgrid``.
        """
        if window is None:
            window = self.window
        return func.pixel_quad_meshgrid(
            window.extent, pad, upsample, config.DTYPE, config.DEVICE, order=order
        )

    @forward
    def coordinate_center_meshgrid(self) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of coordinate locations in the image, centered on the pixel grid."""
        i, j = self.pixel_center_meshgrid()
        return self.pixel_to_plane(i, j)

    @forward
    def coordinate_corner_meshgrid(self) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of coordinate locations in the image, with corners at the pixel grid."""
        i, j = self.pixel_corner_meshgrid()
        return self.pixel_to_plane(i, j)

    @forward
    def coordinate_simpsons_meshgrid(self) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of coordinate locations in the image, with Simpson's rule sampling."""
        i, j = self.pixel_simpsons_meshgrid()
        return self.pixel_to_plane(i, j)

    @forward
    def coordinate_quad_meshgrid(self, order=3) -> Tuple[ArrayLike, ArrayLike]:
        """Get a meshgrid of coordinate locations in the image, with quadrature sampling."""
        i, j, _ = self.pixel_quad_meshgrid(order=order)
        return self.pixel_to_plane(i, j)

    def copy_kwargs(self, **kwargs) -> dict:
        """Constructor keyword arguments reproducing this image; ``kwargs`` override them."""
        kwargs = {
            "_data": backend.copy(self._data),
            "CD": self.CD.value,
            "crpix": self.crpix,
            "crval": self.crval.value,
            "crtan": self.crtan.value,
            "zeropoint": self.zeropoint,
            "identity": self.identity,
            "name": self.name,
            **kwargs,
        }
        return kwargs

    def copy(self, **kwargs):
        """Produce a copy of this image with all of the same properties.

        This can be used when one wishes to make temporary modifications to
        an image and then will want the original again.
        """
        return self.__class__(**self.copy_kwargs(**kwargs))

    def blank_copy(self, **kwargs):
        """Produces a blank copy of the image which has the same properties
        except that its data is now filled with zeros.
        """
        kwargs = {
            "_data": backend.zeros_like(self._data),
            **kwargs,
        }
        return self.copy(**kwargs)

    def crop(self, pixels: Union[int, Tuple[int, int], Tuple[int, int, int, int]], **kwargs):
        """Crop the image by the number of pixels given.

        Given data shape (N, M) the new shape will be:

        crop - int: crop the same number of pixels on all sides. new shape
        (N - 2*crop, M - 2*crop)

        crop - (int, int): crop each dimension by the number of pixels given.
        new shape (N - 2*crop[1], M - 2*crop[0])

        crop - (int, int, int, int): crop each side by the number of pixels
        given assuming (x low, x high, y low, y high). new shape
        (N - crop[2] - crop[3], M - crop[0] - crop[1])

        ``crpix`` is shifted so that the remaining pixels keep their sky
        coordinates. Extra ``kwargs`` are forwarded to :meth:`copy`.
        """
        if np.all(np.array(pixels) == 0):
            # Nothing to crop; only build a new object when overrides are requested.
            return self.copy(**kwargs) if kwargs else self
        if isinstance(pixels, (int, np.integer)):
            pixels = (pixels,)

        nx, ny = self._data.shape[0], self._data.shape[1]
        if len(pixels) == 1:  # same crop in all dimensions
            crop = pixels[0]
            data = self._data[crop : nx - crop, crop : ny - crop]
            crpix = self.crpix - crop
        elif len(pixels) == 2:  # different crop in each dimension
            data = self._data[pixels[0] : nx - pixels[0], pixels[1] : ny - pixels[1]]
            crpix = self.crpix - pixels
        elif len(pixels) == 4:  # different crop on all sides
            data = self._data[pixels[0] : nx - pixels[1], pixels[2] : ny - pixels[3]]
            # Only the low-side crops move the reference pixel.
            crpix = self.crpix - pixels[0::2]
        else:
            raise ValueError(
                f"Invalid crop shape {pixels}, must be (int,), (int, int), or (int, int, int, int)!"
            )
        return self.copy(_data=data, crpix=crpix, **kwargs)

    def reduce(self, scale: int, **kwargs):
        """This operation will downsample an image by the factor given.

        If scale = 2 then 2x2 blocks of pixels will be summed together to form
        individual larger pixels. A new image object will be returned with the
        appropriate pixelscale and data tensor. Note that the window does not
        change in this operation since the pixels are condensed, but the pixel
        size is increased correspondingly. Trailing rows/columns that do not
        fill a complete block are dropped.

        **Args:**
        - `scale` (int): The scale factor by which to reduce the image.
        """
        if not isinstance(scale, (int, np.integer)) and not (
            isinstance(scale, ArrayLike) and scale.dtype is backend.int32
        ):
            raise SpecificationConflict(f"Reduce scale must be an integer! not {type(scale)}")
        if scale == 1:
            # Nothing to reduce; honour overrides instead of silently dropping them.
            return self.copy(**kwargs) if kwargs else self

        MS = self._data.shape[0] // scale
        NS = self._data.shape[1] // scale

        data = self._data[: MS * scale, : NS * scale].reshape(MS, scale, NS, scale).sum(axis=(1, 3))
        CD = self.CD.value * scale
        # Pixel centres sit at integer coordinates (edges at -0.5), so map the
        # reference pixel through the block-averaging of pixel edges.
        crpix = (self.crpix + 0.5) / scale - 0.5

        return self.copy(
            _data=data,
            CD=CD,
            crpix=crpix,
            **kwargs,
        )

    def to(self, dtype=None, device=None):
        """Move parameters, pixel data and zeropoint to ``dtype``/``device``
        (defaulting to ``config.DTYPE``/``config.DEVICE``). Returns ``self``."""
        if dtype is None:
            dtype = config.DTYPE
        if device is None:
            device = config.DEVICE
        super().to(dtype=dtype, device=device)
        self._data = backend.to(self._data, dtype=dtype, device=device)
        if self._zeropoint is not None:
            # Bypass the ``zeropoint`` setter: it re-casts to config.DTYPE /
            # config.DEVICE and would undo the conversion requested here.
            self._zeropoint = backend.to(self._zeropoint, dtype=dtype, device=device)
        return self

    def flatten(self, attribute: str = "data") -> ArrayLike:
        """Return ``getattr(self, attribute)`` flattened to 1D."""
        return backend.flatten(getattr(self, attribute))

    def fits_info(self) -> dict:
        """FITS header cards (1-indexed CRPIX, CD in degrees) describing this image."""
        return {
            "CTYPE1": "RA---TAN",
            "CTYPE2": "DEC--TAN",
            "CRVAL1": self.crval.value[0].item(),
            "CRVAL2": self.crval.value[1].item(),
            "CRPIX1": self.crpix[0] + 1,
            "CRPIX2": self.crpix[1] + 1,
            "CRTAN1": self.crtan.value[0].item(),
            "CRTAN2": self.crtan.value[1].item(),
            "CD1_1": self.CD.value[0][0].item() * arcsec_to_deg,
            "CD1_2": self.CD.value[0][1].item() * arcsec_to_deg,
            "CD2_1": self.CD.value[1][0].item() * arcsec_to_deg,
            "CD2_2": self.CD.value[1][1].item() * arcsec_to_deg,
            # -999 is the "no zeropoint" sentinel understood by load().
            "MAGZP": self.zeropoint.item() if self.zeropoint is not None else -999,
            "IDNTY": self.identity,
        }

    def fits_images(self):
        """HDUs representing this image (a single primary HDU)."""
        return [
            fits.PrimaryHDU(
                backend.to_numpy(backend.transpose(self._data, 1, 0)),
                header=fits.Header(self.fits_info()),
            )
        ]

    def get_astropywcs(self, **kwargs):
        """Build an Astropy WCS from :meth:`fits_info`; ``kwargs`` override header cards."""
        kwargs = {
            "NAXIS": 2,
            # _data is stored (x, y): axis 0 is NAXIS1, axis 1 is NAXIS2.
            "NAXIS1": int(self._data.shape[0]),
            "NAXIS2": int(self._data.shape[1]),
            **self.fits_info(),
            **kwargs,
        }
        return AstropyWCS(kwargs)

    def save(self, filename: str):
        """Write the image to a FITS file, overwriting any existing file."""
        hdulist = fits.HDUList(self.fits_images())
        hdulist.writeto(filename, overwrite=True)

    def load(self, filename: Union[str, fits.HDUList], hduext: int = 0):
        """Load an image from a FITS file.

        This will load HDU ``hduext`` and set the data, CD, crpix, crval, and
        (if present) crtan, zeropoint and identity accordingly. If the
        header's CTYPE1/CTYPE2 are present and not tangent plane, a warning
        is logged.

        Returns the HDUList. When ``filename`` is a path the file is opened
        here and left open for the caller.
        """
        if isinstance(filename, str):
            hdulist = fits.open(filename)
        else:
            hdulist = filename
        hdu = hdulist[hduext]
        header = hdu.header

        for axis, expected in enumerate(self.expect_ctype, start=1):
            ctype = header.get(f"CTYPE{axis}")
            if ctype is not None and ctype not in expected:
                config.logger.warning(
                    "FITS WCS not tangent plane coordinate system! May not be compatible with AstroPhot."
                )

        self.data = np.array(hdu.data, dtype=np.float64)
        self.CD = (
            np.array(
                (
                    (header["CD1_1"], header["CD1_2"]),
                    (header["CD2_1"], header["CD2_2"]),
                ),
                dtype=np.float64,
            )
            * deg_to_arcsec
        )
        # FITS reference pixels are 1-indexed.
        self.crpix = (header["CRPIX1"] - 1, header["CRPIX2"] - 1)
        self.crval = (header["CRVAL1"], header["CRVAL2"])
        if "CRTAN1" in header and "CRTAN2" in header:
            self.crtan = (header["CRTAN1"], header["CRTAN2"])
        if "MAGZP" in header and header["MAGZP"] > -998:
            self.zeropoint = header["MAGZP"]
        self._identity = header.get("IDNTY", str(id(self)))
        return hdulist

    def corners(
        self,
    ) -> Tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike]:
        """Tangent-plane coordinates of the outer pixel corners, ordered
        (lowleft, lowright, upright, upleft)."""
        pixel_lowleft = backend.make_array((-0.5, -0.5), dtype=config.DTYPE, device=config.DEVICE)
        pixel_lowright = backend.make_array(
            (self._data.shape[0] - 0.5, -0.5), dtype=config.DTYPE, device=config.DEVICE
        )
        pixel_upleft = backend.make_array(
            (-0.5, self._data.shape[1] - 0.5), dtype=config.DTYPE, device=config.DEVICE
        )
        pixel_upright = backend.make_array(
            (self._data.shape[0] - 0.5, self._data.shape[1] - 0.5),
            dtype=config.DTYPE,
            device=config.DEVICE,
        )

        lowleft = self.pixel_to_plane(*pixel_lowleft)
        lowright = self.pixel_to_plane(*pixel_lowright)
        upleft = self.pixel_to_plane(*pixel_upleft)
        upright = self.pixel_to_plane(*pixel_upright)
        return (lowleft, lowright, upright, upleft)

    def get_indices(self, other: Window):
        """Slices into this image's ``_data`` covering window ``other``.

        If the window belongs to a different image, its pixel range is shifted
        by the rounded difference of reference pixels, which is only valid
        when both images share the same pixel grid.
        """
        if other.image is self:
            return slice(max(0, other.i_low), min(self._data.shape[0], other.i_high)), slice(
                max(0, other.j_low), min(self._data.shape[1], other.j_high)
            )
        if other.image.identity != self.identity:
            config.logger.warning(
                f"Attempting to match windows with different images! Window image: {other.image.name}, {other.image.identity}, self image: {self.name}, {self.identity}. This may fail unless you are sure the two images are on the same pixel grid."
            )
        shift = np.round(self.crpix - other.crpix).astype(int)
        return slice(
            min(max(0, other.i_low + shift[0]), self._data.shape[0]),
            max(0, min(other.i_high + shift[0], self._data.shape[0])),
        ), slice(
            min(max(0, other.j_low + shift[1]), self._data.shape[1]),
            max(0, min(other.j_high + shift[1], self._data.shape[1])),
        )

    def get_other_indices(self, other: Window):
        """Slices into an array shaped like window ``other`` that overlap this image.

        Raises:
            ValueError: if ``other`` is not a window on this image.
        """
        if other.image == self:  # fixme check identity, or check "is"?
            shape = other.shape
            return slice(
                max(0, -other.i_low), min(self._data.shape[0] - other.i_low, shape[0])
            ), slice(max(0, -other.j_low), min(self._data.shape[1] - other.j_low, shape[1]))
        raise ValueError("get_other_indices only supports windows defined on this same image")

    def get_window(self, other: Union[Window, "Image"], indices=None, **kwargs):
        """Get a new image object which is a window of this image
        corresponding to the other image's window. This will return a new
        image object with the same properties as this one, but with the data
        cropped to the other image's window.
        """
        if indices is None:
            indices = self.get_indices(other if isinstance(other, Window) else other.window)
        new_img = self.copy(
            _data=self._data[indices],
            # Shift the reference pixel by the offset of the cut-out origin.
            crpix=self.crpix - np.array((indices[0].start, indices[1].start)),
            **kwargs,
        )
        return new_img

    def __sub__(self, other):
        """Difference on the overlapping region (Image) or elementwise (scalar/array)."""
        if isinstance(other, Image):
            new_img = self[other]
            new_img._data = new_img._data - other[self]._data
            return new_img
        new_img = self.copy()
        new_img._data = new_img._data - other
        return new_img

    def __add__(self, other):
        """Sum on the overlapping region (Image) or elementwise (scalar/array)."""
        if isinstance(other, Image):
            new_img = self[other]
            new_img._data = new_img._data + other[self]._data
            return new_img
        new_img = self.copy()
        new_img._data = new_img._data + other
        return new_img

    def __iadd__(self, other):
        """Add ``other`` into the overlapping pixels of this image."""
        if isinstance(other, Image):
            self._data = backend.add_at_indices(
                self._data,
                self.get_indices(other.window),
                other._data[other.get_indices(self.window)],
            )
        else:
            self._data = self._data + other
        return self

    def __isub__(self, other):
        """Subtract ``other`` from the overlapping pixels of this image."""
        if isinstance(other, Image):
            self._data = backend.add_at_indices(
                self._data,
                self.get_indices(other.window),
                -other._data[other.get_indices(self.window)],
            )
        else:
            self._data = self._data - other
        return self

    def __getitem__(self, *args):
        """``image[window_or_image]`` returns the cut-out; otherwise defers to Module."""
        if len(args) == 1 and isinstance(args[0], (Image, Window)):
            return self.get_window(args[0])
        return super().__getitem__(*args)
# FIXME: make image lists infinitely nestable; this requires merging `index` and `match_indices` in a consistent way.
class ImageList(Module):
    """Ordered container of :class:`Image` objects.

    Images are matched between lists by their ``identity`` (see
    :meth:`index`), so arithmetic between two lists pairs up images that
    represent the same underlying image regardless of position.
    """

    def __init__(self, images: list[Image], **kwargs):
        super().__init__(**kwargs)
        self.images = list(images)
        if not all(isinstance(image, Image) for image in self.images):
            raise InvalidImage(
                f"Image_List can only hold Image objects, not {tuple(type(image) for image in self.images)}"
            )

    @property
    def data(self):
        """Tuple of each image's ``data``."""
        return tuple(image.data for image in self.images)

    @property
    def _data(self):
        """Tuple of each image's internal (transposed) ``_data``."""
        return tuple(image._data for image in self.images)

    @_data.setter
    def _data(self, value):
        if len(value) != len(self.images):
            raise ValueError(
                f"Expected an object of length {len(self.images)} for _data, but got {type(value)} of length {len(value)}"
            )
        for image, data in zip(self.images, value):
            image._data = data

    @property
    def window(self):
        """WindowList of each image's full window."""
        return WindowList(tuple(image.window for image in self.images))

    def copy(self):
        """New list holding copies of every image."""
        return self.__class__(
            tuple(image.copy() for image in self.images),
        )

    def blank_copy(self):
        """New list holding zero-filled copies of every image."""
        return self.__class__(
            tuple(image.blank_copy() for image in self.images),
        )

    def get_window(self, other: "ImageList"):
        """Cut each image to the corresponding (same position) image of ``other``."""
        return self.__class__(
            tuple(image[win] for image, win in zip(self.images, other.images)),
        )

    def index(self, other: Image):
        """Position of the image whose ``identity`` matches ``other``.

        Raises:
            IndexError: if no image in this list matches.
        """
        for i, image in enumerate(self.images):
            if other.identity == image.identity:
                return i
        raise IndexError(
            f"Could not find identity match between image list {self.name} and input image {other.name}"
        )

    def match_indices(self, other: "ImageList"):
        """Match the indices of the images in this list with those in another Image_List.

        Images of ``other`` without a match here are skipped.
        """
        indices = []
        for other_image in other.images:
            try:
                i = self.index(other_image)
            except IndexError:
                continue
            indices.append(i)
        return indices

    def to(self, dtype=None, device=None):
        """Move the list and all its images to ``dtype``/``device``
        (defaulting to ``config.DTYPE``/``config.DEVICE``). Returns ``self``."""
        if dtype is None:
            dtype = config.DTYPE
        if device is None:
            device = config.DEVICE
        super().to(dtype=dtype, device=device)
        # Image.to also converts the raw pixel arrays; call it explicitly in
        # case Module.to does not recurse into the ``images`` list (idempotent
        # if it already did).
        for image in self.images:
            image.to(dtype=dtype, device=device)
        return self

    def flatten(self, attribute: str = "data") -> ArrayLike:
        """Concatenate every image's flattened ``attribute`` into one 1D array."""
        return backend.concatenate(tuple(image.flatten(attribute) for image in self.images))

    def __sub__(self, other):
        """Pairwise difference for images matched by identity; unmatched images are skipped."""
        if isinstance(other, ImageList):
            new_list = []
            for other_image in other.images:
                try:
                    i = self.index(other_image)
                except IndexError:
                    continue
                self_image = self.images[i]
                new_list.append(self_image - other_image)
            return self.__class__(new_list)
        else:
            raise ValueError("Subtraction of Image_List only works with another Image_List object!")

    def __add__(self, other):
        """Pairwise sum for images matched by identity; unmatched images are skipped."""
        if isinstance(other, ImageList):
            new_list = []
            for other_image in other.images:
                try:
                    i = self.index(other_image)
                except IndexError:
                    continue
                self_image = self.images[i]
                new_list.append(self_image + other_image)
            return self.__class__(new_list)
        else:
            raise ValueError("Addition of Image_List only works with another Image_List object!")

    def __isub__(self, other):
        """In-place subtraction of an ImageList (matched by identity) or a single Image."""
        if isinstance(other, ImageList):
            for other_image in other.images:
                try:
                    i = self.index(other_image)
                except IndexError:
                    continue
                self.images[i] -= other_image
        elif isinstance(other, Image):
            i = self.index(other)
            self.images[i] -= other
        else:
            raise ValueError("Subtraction of Image_List only works with another Image_List object!")
        return self

    def __iadd__(self, other):
        """In-place addition of an ImageList (matched by identity) or a single Image."""
        if isinstance(other, ImageList):
            for other_image in other.images:
                try:
                    i = self.index(other_image)
                except IndexError:
                    continue
                self.images[i] += other_image
        elif isinstance(other, Image):
            i = self.index(other)
            self.images[i] += other
        else:
            raise ValueError("Addition of Image_List only works with another Image_List object!")
        return self

    def __getitem__(self, *args):
        """Index by ImageList/WindowList (cut-outs), Image/Window (one cut-out) or int."""
        if len(args) == 1:
            if isinstance(args[0], ImageList):
                new_list = []
                for other_image in args[0].images:
                    i = self.index(other_image)
                    new_list.append(self.images[i].get_window(other_image))
                return self.__class__(new_list)
            elif isinstance(args[0], WindowList):
                new_list = []
                for other_window in args[0].windows:
                    i = self.index(other_window.image)
                    new_list.append(self.images[i].get_window(other_window))
                return self.__class__(new_list)
            elif isinstance(args[0], Image):
                i = self.index(args[0])
                return self.images[i].get_window(args[0])
            elif isinstance(args[0], Window):
                i = self.index(args[0].image)
                return self.images[i].get_window(args[0])
            elif isinstance(args[0], int):
                return self.images[args[0]]
        # Previously this result was discarded, so unsupported keys returned None.
        return super().__getitem__(*args)

    def __iter__(self):
        return (img for img in self.images)
class ImageBatchMixin:
    """Mixin that turns an ImageList into a batch of equally sized images.

    An ImageBatch restricts which images it may hold, and in exchange it
    exposes stacked (batched) views of their pixel data and WCS parameters
    so that operations can be vectorized over the batch. Keep in mind:

    - Only the regular image type is supported (not SIP or CMOS images yet).
    - All images must share one shape, otherwise the batch operations fail.
    - The ImageBatch does not accelerate anything by itself; it exists to
      support the BatchSceneModel.
    - In every other respect an ImageBatch behaves like a regular ImageList.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        members = self.images

        # Every member must be an Image ...
        if any(not isinstance(member, Image) for member in members):
            raise InvalidImage(
                f"ImageBatch can only hold Image objects, not {tuple(type(image) for image in self.images)}"
            )
        # ... of exactly the first member's class ...
        if any(not isinstance(member, members[0].__class__) for member in members):
            raise InvalidImage(
                f"ImageBatch images must all be of the same type, not {tuple(type(image) for image in self.images)}"
            )
        # ... and with identical data shapes so they can be stacked.
        if any(member.data.shape != members[0].data.shape for member in members):
            raise InvalidImage(
                f"All images in an ImageBatch must have the same shape, but got shapes {tuple(image.data.shape for image in self.images)}"
            )

    @property
    def data(self):
        """Every image's ``data`` stacked along a new leading batch axis."""
        per_image = tuple(member.data for member in self.images)
        return backend.stack(per_image, dim=0)

    @ImageList._data.getter
    def _data(self):
        """Every image's internal ``_data`` stacked along a new leading axis.

        The setter inherited from ``ImageList._data`` is kept.
        """
        per_image = tuple(member._data for member in self.images)
        return backend.stack(per_image, dim=0)

    @property
    def window(self):
        """WindowBatch built from each image's full window."""
        per_image = tuple(member.window for member in self.images)
        return WindowBatch(per_image)

    @property
    def crval(self):
        """Stacked ``crval`` parameter values, one row per image."""
        per_image = tuple(member.crval.value for member in self.images)
        return backend.stack(per_image, dim=0)

    @property
    def crtan(self):
        """Stacked ``crtan`` parameter values, one row per image."""
        per_image = tuple(member.crtan.value for member in self.images)
        return backend.stack(per_image, dim=0)

    @property
    def CD(self):
        """Stacked ``CD`` matrices, one per image."""
        per_image = tuple(member.CD.value for member in self.images)
        return backend.stack(per_image, dim=0)

    @property
    def crpix(self):
        """Stacked reference pixels converted to a backend array."""
        stacked = np.stack(tuple(member.crpix for member in self.images), axis=0)
        return backend.as_array(
            stacked,
            dtype=config.DTYPE,
            device=config.DEVICE,
        )