Source code for astrophot.models.func.integration

from typing import Tuple
import numpy as np

from ...utils.integration import quad_table
from ...backend_obj import backend, ArrayLike
from ... import config


def pixel_center_integrator(Z: ArrayLike) -> ArrayLike:
    """
    Integrate pixels by taking the value sampled at each pixel center.

    No computation is performed: the input is handed back as the result,
    so each sample in `Z` is used directly as its pixel's value.

    **Args:**
    - `Z`: The samples to use as pixel values (presumably one sample per
      pixel center — confirm against callers).

    **Returns:**
    The very same object that was passed in (no copy is made).
    """
    pixel_values = Z
    return pixel_values
def pixel_simpsons_integrator(Z: ArrayLike) -> ArrayLike:
    """
    Integrate pixels with a 2D Simpson's-rule stencil.

    The 3x3 stencil is the outer product of the 1D Simpson weights
    ``(1, 4, 1)``, divided by 36 so its entries sum to one. It is convolved
    over `Z` with stride 2 and no padding, so each output value combines a
    3x3 neighbourhood of samples and adjacent outputs share their border
    samples.

    **Args:**
    - `Z`: 2D grid of samples. Presumably this holds pixel corners, edge
      midpoints and centers, i.e. ``2N + 1`` samples per axis for ``N``
      pixels (TODO confirm against callers).

    **Returns:**
    2D array with one Simpson-weighted average per stencil position.
    """
    weights_1d = (1, 4, 1)
    stencil = [[a * b for b in weights_1d] for a in weights_1d]
    kernel = backend.as_array([[stencil]], dtype=config.DTYPE, device=config.DEVICE) / 36.0
    # Add leading batch and channel axes of size 1 for the 2D convolution.
    batched = Z.reshape(1, 1, *Z.shape)
    smoothed = backend.conv2d(batched, kernel, padding="valid", stride=2)
    # Drop the batch and channel axes again.
    return smoothed.squeeze(0).squeeze(0)
def pixel_quad_integrator(Z: ArrayLike, w: ArrayLike = None, order: int = 3) -> ArrayLike:
    """
    Integrate the pixel values using quadrature weights.

    The quadrature samples of each pixel lie along the last axis of `Z`.
    They are multiplied by their weights and summed, giving one value per
    pixel.

    **Args:**
    - `Z`: The tensor containing pixel values, with each pixel's quadrature
      samples along the last axis (``order**2`` of them when `w` is
      omitted).
    - `w`: The quadrature weights, broadcastable against `Z`. If omitted,
      the weights for `order` are taken from `quad_table` and flattened.
    - `order`: The order of the quadrature. Only used when `w` is omitted.

    **Returns:**
    `Z * w` summed over the last axis.
    """
    if w is None:
        _, _, w = quad_table(order, config.DTYPE, config.DEVICE)
        # The rest of this module (e.g. single_quad_integrate) flattens
        # quad_table's weights to match samples stored on a flat
        # (..., order**2) axis. Without this, a 2D weight grid would fail to
        # broadcast or be summed over only one of its axes. It is a no-op
        # if the weights are already flat.
        w = w.flatten()
    return backend.sum(Z * w, dim=-1)
def upsample(i: ArrayLike, j: ArrayLike, order: int, scale: float) -> Tuple[ArrayLike, ArrayLike]:
    """
    Split each point into an ``order x order`` grid of sub-pixel centers.

    Along each axis, the offsets are the centers of `order` equal
    sub-intervals of a unit interval centered on zero, so they span
    ``±(order - 1) / (2 * order)``. They are then multiplied by `scale`.

    **Args:**
    - `i`, `j`: Coordinates of the points to subdivide.
    - `order`: Number of sub-pixels along each axis.
    - `scale`: Width of the region being subdivided, in the units of `i`/`j`.

    **Returns:**
    `(si, sj)`: the inputs with a new trailing axis of length ``order**2``
    holding the sub-pixel coordinates.
    """
    offsets = (
        backend.linspace(-1, 1, order, dtype=config.DTYPE, device=config.DEVICE)
        * (order - 1)
        / (2.0 * order)
    )
    grid_i, grid_j = backend.meshgrid(offsets, offsets, indexing="xy")
    n_sub = order**2

    def _spread(center, grid):
        # Replicate each center n_sub times along a new last axis, then shift
        # the copies by the scaled offsets.
        return backend.repeat(center[..., None], n_sub, -1) + scale * grid.flatten()

    return _spread(i, grid_i), _spread(j, grid_j)
def single_quad_integrate(
    i: ArrayLike, j: ArrayLike, brightness_ij, scale: float, quad_order: int = 3
) -> Tuple[ArrayLike, ArrayLike]:
    """
    Evaluate one level of quadrature around each point.

    The quadrature nodes from `quad_table` are scaled by `scale` and added
    to every `(i, j)` point. `brightness_ij` is evaluated at all nodes, and
    the samples are reduced in two ways.

    **Args:**
    - `i`, `j`: Coordinates of the points to integrate around.
    - `brightness_ij`: Callable evaluated at the node coordinates.
    - `scale`: Factor applied to the quadrature node offsets.
    - `quad_order`: Quadrature order. ``quad_order**2`` nodes are used per
      point.

    **Returns:**
    `(z, z0)`:
    - `z`: the quadrature-weighted sum of the samples for each point.
    - `z0`: the plain mean of the same samples.

    `recursive_quad_integrate` uses ``|z - z0|`` to choose which points to
    refine.
    """
    node_i, node_j, weights = quad_table(quad_order, config.DTYPE, config.DEVICE)
    n_nodes = quad_order**2
    sample_i = backend.repeat(i[..., None], n_nodes, -1) + scale * node_i.flatten()
    sample_j = backend.repeat(j[..., None], n_nodes, -1) + scale * node_j.flatten()
    samples = brightness_ij(sample_i, sample_j)
    plain_mean = backend.mean(samples, dim=-1)
    weighted = backend.sum(samples * weights.flatten(), dim=-1)
    return weighted, plain_mean
def recursive_quad_integrate(
    i: ArrayLike,
    j: ArrayLike,
    brightness_ij: callable,
    curve_frac: float,
    scale: float = 1.0,
    quad_order: int = 3,
    gridding: int = 5,
    _current_depth: int = 0,
    max_depth: int = 1,
) -> ArrayLike:
    """
    Integrate with quadrature, recursively refining the least-smooth points.

    Each point is first integrated with `single_quad_integrate`. The
    ``curve_frac`` fraction of points (at least one) with the largest
    ``|weighted - mean|`` difference is then split into a
    ``gridding x gridding`` sub-grid with `upsample`. Each sub-pixel is
    integrated recursively at ``scale / gridding``, and the sub-pixel
    results are averaged to replace the coarse value. Recursion stops once
    ``_current_depth`` reaches ``max_depth``.

    **Args:**
    - `i`, `j`: Coordinates of the points to integrate.
    - `brightness_ij`: Callable evaluated at sample coordinates.
    - `curve_frac`: Fraction of points to refine at each level. It is
      clamped so that at least one and at most all points are refined.
    - `scale`: Width of the region around each point.
    - `quad_order`: Quadrature order passed to `single_quad_integrate`.
    - `gridding`: Sub-pixels per axis when a point is refined.
    - `_current_depth`: Internal recursion counter.
    - `max_depth`: Number of refinement levels.

    **Returns:**
    An array shaped like the quadrature result for `i`, with the refined
    entries replaced.
    """
    z, z0 = single_quad_integrate(i, j, brightness_ij, scale, quad_order)
    if _current_depth >= max_depth:
        return z

    n_total = int(np.prod(z.shape))
    if n_total == 0:
        # Nothing to refine, and topk cannot select an element from an
        # empty input.
        return z
    # Refine at least one point, but never request more elements than
    # exist. curve_frac > 1 would otherwise make topk raise.
    N = min(n_total, max(1, int(n_total * curve_frac)))
    select = backend.topk(backend.abs(z - z0).flatten(), N)[1]
    integral_flat = z.flatten()
    si, sj = upsample(i.flatten()[select], j.flatten()[select], gridding, scale)
    refined = recursive_quad_integrate(
        si,
        sj,
        brightness_ij,
        curve_frac=curve_frac,
        scale=scale / gridding,
        quad_order=quad_order,
        gridding=gridding,
        _current_depth=_current_depth + 1,
        max_depth=max_depth,
    )
    # Sub-pixels have equal area, so the mean over the sub-grid axis is the
    # refined value of the parent point.
    integral_flat = backend.fill_at_indices(integral_flat, select, backend.mean(refined, dim=-1))
    return integral_flat.reshape(z.shape)
def bright_integrate(
    z: ArrayLike,
    i: ArrayLike,
    j: ArrayLike,
    brightness_ij: callable,
    bright_frac: float,
    scale: float = 1.0,
    quad_order: int = 3,
    gridding: int = 5,
    max_depth: int = 2,
) -> ArrayLike:
    """
    Refine the values of the brightest entries by repeated subdivision.

    At each of the `max_depth` levels, the ``bright_frac`` fraction of the
    current values (at least one) with the largest values is selected:
    - On the first level, the selected points are re-integrated directly
      with `single_quad_integrate` at `scale`.
    - On later levels, each selected (sub-)pixel is first split into a
      ``gridding x gridding`` grid with `upsample`, and `scale` is divided
      by `gridding`.

    After the last level, the refined values are averaged back up one level
    at a time and written into the positions they were selected from.

    **Args:**
    - `z`: Starting values, one per `(i, j)` point (presumably the
      pixel-center brightness — confirm against callers).
    - `i`, `j`: Coordinates matching `z`.
    - `brightness_ij`: Callable evaluated at sample coordinates.
    - `bright_frac`: Fraction of entries to refine at each level. It is
      clamped so that at least one and at most all entries are refined.
    - `scale`: Width of the region around each point.
    - `quad_order`: Quadrature order passed to `single_quad_integrate`.
    - `gridding`: Sub-pixels per axis on levels after the first.
    - `max_depth`: Number of refinement levels.

    **Returns:**
    An array shaped like `z`. The refined entries hold their new values and
    all others are copied from `z`.
    """
    if int(np.prod(z.shape)) == 0:
        # Nothing to refine, and topk cannot select an element from an
        # empty input.
        return z

    # One entry per level: [flattened values, indices refined at that level,
    # shape before flattening]. The final level has no selection.
    trace = []
    for d in range(max_depth):
        n_total = int(np.prod(z.shape))
        # Never request more elements than exist. bright_frac > 1 would
        # otherwise make topk raise.
        N = min(n_total, max(1, int(n_total * bright_frac)))
        z_flat = z.flatten()
        select = backend.topk(z_flat, N)[1]
        trace.append([z_flat, select, z.shape])
        if d > 0:
            i, j = upsample(i.flatten()[select], j.flatten()[select], gridding, scale)
            scale = scale / gridding
        else:
            # First level: re-integrate the chosen points as they are. The
            # trailing axis of length 1 lets the unwind step average over
            # dim=-1 uniformly.
            i, j = i.flatten()[select].reshape(-1, 1), j.flatten()[select].reshape(-1, 1)
        z, _ = single_quad_integrate(i, j, brightness_ij, scale, quad_order)
    trace.append([z, None, z.shape])

    # Unwind from the finest level. Each level's values are averaged over
    # their trailing axis and written into the parent level at the indices
    # selected there.
    for _ in range(max_depth):
        T = trace.pop(-1)
        trace[-1][0] = backend.fill_at_indices(
            trace[-1][0], trace[-1][1], backend.mean(T[0].reshape(T[2]), dim=-1)
        )
    return trace[0][0].reshape(trace[0][2])