Source code for astrophot.image.window

from typing import Union, Tuple, List

import numpy as np

from ..errors import InvalidWindow

__all__ = ("Window", "WindowList", "WindowBatch")


[docs] class Window: def __init__( self, window: Union[Tuple[int, int, int, int], Tuple[Tuple[int, int], Tuple[int, int]]], image: "Image", ): self.extent = window self.image = image @property def identity(self): return self.image.identity @property def crpix(self): return self.image.crpix @property def shape(self): return (self.i_high - self.i_low, self.j_high - self.j_low) @property def extent(self): return (self.i_low, self.i_high, self.j_low, self.j_high) @extent.setter def extent( self, value: Union[Tuple[int, int, int, int], Tuple[Tuple[int, int], Tuple[int, int]]] ): if len(value) == 4: self.i_low, self.i_high, self.j_low, self.j_high = value elif len(value) == 2: self.i_low, self.j_low = value[0] self.i_high, self.j_high = value[1] else: raise ValueError( "Extent must be formatted as (i_low, i_high, j_low, j_high) or ((i_low, j_low), (i_high, j_high))" )
[docs] def chunk(self, chunk_size: int) -> List["Window"]: # number of pixels on each axis px = self.i_high - self.i_low py = self.j_high - self.j_low # total number of chunks desired chunk_tot = int(np.ceil((px * py) / chunk_size)) # number of chunks on each axis cx = int(np.ceil(np.sqrt(chunk_tot * px / py))) cy = int(np.ceil(chunk_tot / cx)) # number of pixels on each axis per chunk stepx = int(np.ceil(px / cx)) stepy = int(np.ceil(py / cy)) # create the windows windows = [] for i in range(self.i_low, self.i_high, stepx): for j in range(self.j_low, self.j_high, stepy): i_high = min(i + stepx, self.i_high) j_high = min(j + stepy, self.j_high) windows.append(Window((i, i_high, j, j_high), self.image)) return windows
[docs] def pad(self, pad: int): self.i_low -= pad self.i_high += pad self.j_low -= pad self.j_high += pad
[docs] def copy(self): return Window((self.i_low, self.i_high, self.j_low, self.j_high), self.image)
def __or__(self, other: "Window"): if not isinstance(other, Window): raise TypeError(f"Cannot combine Window with {type(other)}") if self.image != other.image: raise InvalidWindow( f"Cannot combine Windows from different images: {self.image.identity} and {other.image.identity}" ) new_i_low = min(self.i_low, other.i_low) new_i_high = max(self.i_high, other.i_high) new_j_low = min(self.j_low, other.j_low) new_j_high = max(self.j_high, other.j_high) return Window((new_i_low, new_i_high, new_j_low, new_j_high), self.image) def __ior__(self, other: "Window"): if not isinstance(other, Window): raise TypeError(f"Cannot combine Window with {type(other)}") if self.image != other.image: raise InvalidWindow( f"Cannot combine Windows from different images: {self.image.identity} and {other.image.identity}" ) self.i_low = min(self.i_low, other.i_low) self.i_high = max(self.i_high, other.i_high) self.j_low = min(self.j_low, other.j_low) self.j_high = max(self.j_high, other.j_high) return self def __and__(self, other: "Window"): if not isinstance(other, Window): raise TypeError(f"Cannot intersect Window with {type(other)}") if self.image.identity != other.image.identity: raise InvalidWindow( f"Cannot combine Windows from different images: {self.image.identity} and {other.image.identity}" ) if ( self.i_high <= other.i_low or self.i_low >= other.i_high or self.j_high <= other.j_low or self.j_low >= other.j_high ): return Window((0, 0, 0, 0), self.image) # fixme handle crpix new_i_low = max(self.i_low, other.i_low) new_i_high = min(self.i_high, other.i_high) new_j_low = max(self.j_low, other.j_low) new_j_high = min(self.j_high, other.j_high) return Window((new_i_low, new_i_high, new_j_low, new_j_high), self.image) def __str__(self): return f"Window({self.i_low}, {self.i_high}, {self.j_low}, {self.j_high})"
[docs] class WindowList: def __init__(self, windows: list[Window]): if not all(isinstance(window, Window) for window in windows): raise InvalidWindow( f"Window_List can only hold Window objects, not {tuple(type(window) for window in windows)}" ) self.windows = windows
[docs] def index(self, other: Window) -> int: for i, window in enumerate(self.windows): if other.identity == window.identity: return i else: raise IndexError("Could not find identity match between window list and input window")
def __and__(self, other: "WindowList"): if not isinstance(other, WindowList): raise TypeError(f"Cannot intersect WindowList with {type(other)}") if len(self.windows) == 0 or len(other.windows) == 0: return WindowList([]) new_windows = [] for other_window in other.windows: try: i = self.index(other_window) except IndexError: continue # skip if the window is not in self.windows new_windows.append(self.windows[i] & other_window) return WindowList(new_windows) def __getitem__(self, index): return self.windows[index] def __len__(self): return len(self.windows) def __iter__(self): return iter(self.windows)
[docs] class WindowBatch(WindowList): def __init__(self, windows: list[Window]): super().__init__(windows) if not all(window.shape == self.windows[0].shape for window in self.windows): raise InvalidWindow("All windows in a WindowBatch must have the same shape.")
[docs] def origin_shifter(self, other: Window) -> np.ndarray: """Returns the shift in origin between this window batch and another window.""" if not isinstance(other, Window): raise TypeError(f"Can only compute origin shift with a Window, not {type(other)}") shifts = [] for window in self.windows: shift_i = other.i_low - window.i_low shift_j = other.j_low - window.j_low shifts.append((shift_i, shift_j)) return np.array(shifts)