from typing import Tuple
import numpy as np
import torch
from .sky_model_object import SkyModel
from ..utils.decorators import ignore_numpy_warnings, combine_docstrings
from ..utils.interpolate import interp2d
from ..param import forward
from ..backend_obj import backend, ArrayLike
from . import func
from ..utils.initialize import polar_decomposition
# Names exported by ``from <module> import *``.
__all__ = ["BilinearSky"]
def _block_median(dat: np.ndarray, nodes: Tuple[int, int], fill: float) -> np.ndarray:
    """Return the NaN-ignoring median of ``dat`` over a grid of contiguous blocks.

    ``dat`` is cut into ``nodes[0] x nodes[1]`` equally sized, contiguous
    blocks (trailing rows/columns that do not fill a whole block are dropped)
    and the median of each block is returned as an array of shape
    ``(nodes[0], nodes[1])``. If ``dat`` has fewer rows than ``nodes[0]`` or
    fewer columns than ``nodes[1]``, no block can be formed and every node is
    set to ``fill``.
    """
    n_i, n_j = nodes[0], nodes[1]
    block_i = dat.shape[0] // n_i
    block_j = dat.shape[1] // n_j
    if block_i == 0 or block_j == 0:
        # Image smaller than the node grid: nothing to take a median over.
        return np.full((n_i, n_j), fill)
    # The node index must be the *slow* axis of each split. With C-order
    # reshaping, row r maps to (r // block_i, r % block_i), so axis 0 selects
    # the node and axis 1 walks the rows inside that node's block. Splitting
    # as (block_i, n_i) instead would interleave rows from the whole image
    # into every node.
    blocks = dat[: block_i * n_i, : block_j * n_j].reshape(n_i, block_i, n_j, block_j)
    return np.nanmedian(blocks, axis=(1, 3))


@combine_docstrings
class BilinearSky(SkyModel):
    """Sky background model using a coarse bilinear grid for the sky flux.

    **Parameters:**

    - `I`: sky brightness grid
    - `PA`: position angle of the sky grid in radians.
    - `scale`: scale of the sky grid in arcseconds per grid unit.
    """

    _model_type = "bilinear"
    _parameter_specs = {
        "I": {"units": "flux/arcsec^2", "shape": (None, None), "dynamic": True},
        "PA": {"units": "radians", "shape": (), "dynamic": True},
        "scale": {"units": "arcsec/grid-unit", "shape": (), "dynamic": True},
    }
    usable = True

    def __init__(self, *args, nodes: Tuple[int, int] = (3, 3), **kwargs):
        """Initialize the BilinearSky model with a grid of nodes.

        Args:
            nodes: Number of grid nodes along each axis of ``I``. Within this
                class it sizes the initial ``I`` grid and the initial
                ``scale``; if ``I`` is supplied, ``initialize`` replaces it
                with ``I``'s shape.
        """
        super().__init__(*args, **kwargs)
        # Normalise to a tuple so it matches what ``initialize`` stores when
        # ``I`` is given explicitly (e.g. a list argument becomes a tuple).
        self.nodes = tuple(nodes)

    @torch.no_grad()
    @ignore_numpy_warnings
    def initialize(self):
        """Fill in whichever of ``I``, ``PA`` and ``scale`` are not yet initialized.

        - ``PA`` comes from the rotation factor of the target's CD matrix.
        - ``scale`` spreads ``nodes[0]`` grid units over the target's first
          data axis.
        - ``I`` is the per-block median of the windowed target data, with
          masked pixels replaced by the median of the unmasked pixels. It is
          then divided by the target pixel area, since ``I`` is specified in
          flux/arcsec^2.
        """
        super().initialize()
        if self.I.initialized:
            # A user-supplied grid dictates the node layout.
            self.nodes = tuple(self.I.value.shape)
        if not self.PA.initialized:
            R, _ = polar_decomposition(self.target.CD.npvalue)
            # NOTE(review): arccos(|R[0, 0]|) folds the angle into [0, pi/2];
            # confirm the loss of sign/quadrant is intended.
            self.PA.value = np.arccos(np.abs(R[0, 0]))
        if not self.scale.initialized:
            # NOTE(review): this uses the full target's first-axis size, while
            # ``I`` below is built from the windowed data; confirm which extent
            # the grid is meant to span.
            self.scale.value = (
                self.target.pixelscale.item() * self.target._data.shape[0] / self.nodes[0]
            )
        if self.I.initialized:
            return

        target_dat = self.target[self.window]
        dat = backend.to_numpy(target_dat._data).copy()
        # Cast so boolean indexing and ``~`` behave as a pixel mask even if the
        # backend hands back a non-bool array.
        mask = backend.to_numpy(target_dat._mask).astype(bool)
        # Fill masked pixels with the median of the *unmasked* ones so bad
        # pixels cannot bias the fill value. Fall back to all pixels if
        # everything is masked.
        unmasked = dat[~mask]
        fill = np.nanmedian(unmasked) if unmasked.size else np.nanmedian(dat)
        dat[mask] = fill
        self.I.value = _block_median(dat, self.nodes, fill) / self.target.pixel_area.item()

    @forward
    def brightness(self, x: ArrayLike, y: ArrayLike, I: ArrayLike) -> ArrayLike:
        """Interpolate the sky grid ``I`` at coordinates ``(x, y)``.

        ``x`` and ``y`` are first passed through ``transform_coordinates``
        (presumably mapping them into grid units using ``PA`` and ``scale``;
        confirm against that method). The result is then evaluated on ``I``
        with ``interp2d``.
        """
        x, y = self.transform_coordinates(x, y)
        return interp2d(I, x, y)