Source code for astrophot.models.mixins.sample

from typing import Optional, Literal

import numpy as np

from ...param import forward
from ...backend_obj import backend, ArrayLike
from ... import config
from ...image import JacobianImage
from .. import func
from ...errors import SpecificationConflict
from ...utils.integration import quad_table


class SampleMixin:
    """
    Methods for integrating the model from a smooth model defined in the
    tangent plane into individual pixel fluxes.

    This is done in a two step process. First the model is sampled at a set of
    points within each pixel, and then an adaptive integration method is used
    to further integrate pixels where it has identified the need for
    additional accuracy.

    The `sampling_mode` option controls this first step. It determines at what
    level of depth every pixel is integrated. The midpoint option is the least
    accurate (and fastest): it just samples the center of each pixel. After
    that, each method trades more compute for more accuracy. The `quad:x`
    method is the most accurate; it uses Gaussian quadrature integration with
    x points along each pixel axis. Note that `quad:5` means that each pixel
    will be sampled at 25 points (5^2) to determine the flux in that pixel.
    `simpsons` is often a good middle ground. For models over a small number
    of pixels you will likely not notice the runtime difference between
    midpoint and some higher accuracy method, since other aspects of the
    fitting process also take up some time.

    The `integrate_mode` option controls the second step, which is an adaptive
    integration method that identifies and integrates pixels where the model
    needs extra accuracy. The default method is `bright`, which identifies the
    brightest pixels and then uses quadrature integration to further integrate
    those pixels. The default parameters are to recursively integrate the
    brightest 5% of pixels up to a maximum depth of 2 recursive levels. Each
    level does a 5x upsampling and then uses 3rd order quadrature to integrate
    the super sampled pixels. This means that the most highly integrated
    pixels will be 5x upsampled twice and then 3x sampled for the quadrature,
    effectively like upsampling 75 times the starting resolution for those
    pixels, but only for 0.25% of the pixels. Doing this roughly doubles the
    amount of compute needed to sample an image relative to midpoint sampling,
    but gives a massive boost in accuracy for models which change rapidly
    across a pixel.

    Note: JAX does not play nicely with the adaptive integration methods, so
    it massively slows down the jit compilation and the final sampling speed.
    With JAX it is generally better to set `integrate_mode` to `none` and use
    a higher accuracy `sampling_mode` such as `quad:5`.

    :param sampling_mode: The method used to sample the model in image pixels.
        Options are:

        - `auto`: Automatically choose the sampling method based on the number
          of pixels in the model window (default): `quad:5` for at most 100
          pixels, `simpsons` for at most 10000 pixels, otherwise `midpoint`
          (also `midpoint` if the window size cannot be determined).
        - `midpoint`: Use midpoint sampling, evaluate the brightness at the
          center of each pixel.
        - `simpsons`: Use Simpson's rule to integrate each pixel.
        - `upsample:x`: Upsample the pixel in a regular grid of size x (odd
          positive integer), generally less accurate than quad:x.
        - `quad:x`: Use quadrature sampling with order x, where x is an odd
          positive integer, to integrate each pixel.
    :param integrate_mode: The method used to select pixels to integrate
        further where the model varies significantly. Options are:

        - `none`: No extra integration is performed (beyond the
          sampling_mode).
        - `bright`: Select the brightest pixels for further integration
          (default).
        - `curvature`: Select pixels which show signs of significant higher
          order derivatives.
    :param integrate_fraction: The fraction of the pixels to super sample
        during integration (default: 0.05).
    :param integrate_max_depth: The maximum depth of the integration method
        (default: 2).
    :param integrate_gridding: The gridding used for the integration method to
        super-sample a pixel at each iteration (default: 5).
    :param integrate_quad_order: The order of the quadrature used for the
        integration method on the super sampled pixels (default: 3).
    """

    integrate_fraction = 0.05  # fraction of the pixels to super sample
    integrate_max_depth = 2
    integrate_gridding = 5
    integrate_quad_order = 3

    _options = (
        "sampling_mode",
        "integrate_mode",
        "integrate_fraction",
        "integrate_max_depth",
        "integrate_gridding",
        "integrate_quad_order",
    )

    def __init__(self, *args, sampling_mode="auto", integrate_mode="bright", **kwargs):
        super().__init__(*args, **kwargs)
        # Both assignments go through the property setters below, which
        # validate the mode and install the matching sampling/integration
        # callables.
        self.sampling_mode = sampling_mode
        self.integrate_mode = integrate_mode

    @forward
    def _bright_integrate(
        self,
        sample: ArrayLike,
        i: ArrayLike,
        j: ArrayLike,
        upsample: int,
        pixel_brightness: callable,
    ) -> ArrayLike:
        """Adaptively re-integrate the brightest pixels of ``sample``.

        Delegates to ``func.bright_integrate`` using this model's
        ``integrate_*`` settings. The pixel scale passed on is the target's
        base scale divided by ``upsample``.
        """
        sample = func.bright_integrate(
            sample,
            i,
            j,
            pixel_brightness,
            scale=self.target.base_scale / upsample,
            bright_frac=self.integrate_fraction,
            quad_order=self.integrate_quad_order,
            gridding=self.integrate_gridding,
            max_depth=self.integrate_max_depth,
        )
        return sample

    @forward
    def _curvature_integrate(
        self,
        sample: ArrayLike,
        i: ArrayLike,
        j: ArrayLike,
        upsample: int,
        pixel_brightness: callable,
    ) -> ArrayLike:
        """Adaptively re-integrate the pixels with the largest curvature.

        ``sample`` is convolved with ``func.curvature_kernel``. The top
        ``integrate_fraction`` of pixels (at least one) by absolute response
        are re-integrated with ``func.recursive_quad_integrate``. Those values
        are written back into a copy of ``sample`` with the same shape.
        """
        kernel = func.curvature_kernel(config.DTYPE, config.DEVICE)
        # "valid" convolution shrinks the image. The replicate-pad of one
        # pixel per side presumably restores the original shape (assumes a
        # 3x3 kernel, TODO confirm).
        curvature = (
            backend.abs(
                backend.pad(
                    backend.conv2d(
                        sample.reshape(1, 1, *sample.shape),
                        kernel.reshape(1, 1, *kernel.shape),
                        padding="valid",
                    ),
                    (0, 0, 0, 0, 1, 1, 1, 1),
                    mode="replicate",
                )
            )
            .squeeze(0)
            .squeeze(0)
        )
        N = max(1, int(np.prod(i.shape) * self.integrate_fraction))
        select = backend.topk(curvature.flatten(), N)[1]
        sample_flat = sample.flatten()
        sample_flat = backend.fill_at_indices(
            sample_flat,
            select,
            func.recursive_quad_integrate(
                i.flatten()[select],
                j.flatten()[select],
                pixel_brightness,
                scale=self.target.base_scale / upsample,
                curve_frac=self.integrate_fraction,
                quad_order=self.integrate_quad_order,
                gridding=self.integrate_gridding,
                max_depth=self.integrate_max_depth,
            ),
        )
        return sample_flat.reshape(sample.shape)

    @property
    def sampling_mode(self):
        """The resolved sampling mode (``auto`` is replaced by a concrete mode)."""
        return self._sampling_mode

    @sampling_mode.setter
    def sampling_mode(self, sampling_mode):
        if sampling_mode == "auto":
            # Small windows can afford a more expensive per-pixel sampling.
            sampling_mode = "midpoint"
            try:
                N = np.prod(self.window.shape)
                if N <= 100:
                    sampling_mode = "quad:5"
                elif N <= 10000:
                    sampling_mode = "simpsons"
            except Exception:
                # Best effort: if the window size cannot be determined here
                # (presumably no window is available yet), keep "midpoint".
                # Unlike a bare ``except:``, this does not swallow
                # KeyboardInterrupt / SystemExit.
                pass

        if sampling_mode == "midpoint":
            self._pixel_meshgridder = lambda im, w, p, u: im.pixel_center_meshgrid(w, p, u)
            self._pixel_integrator = func.pixel_center_integrator
            self._pixel_center_finder = lambda i, j: (i, j)
        elif sampling_mode == "simpsons":
            self._pixel_meshgridder = lambda im, w, p, u: im.pixel_simpsons_meshgrid(w, p, u)
            self._pixel_integrator = func.pixel_simpsons_integrator
            # Pixel centers sit at the odd indices of the Simpson's grid.
            self._pixel_center_finder = lambda i, j: (i[1:-1:2, 1:-1:2], j[1:-1:2, 1:-1:2])
        elif sampling_mode.startswith("quad:"):
            order = int(sampling_mode.split(":")[1])
            self._pixel_meshgridder = lambda im, w, p, u: im.pixel_quad_meshgrid(
                w, p, u, order=order
            )[:2]
            _, _, w = quad_table(order, config.DTYPE, config.DEVICE)
            w = w.flatten()
            self._pixel_integrator = lambda z: func.pixel_quad_integrator(z, w)
            # The middle of the order**2 quadrature points is the pixel center
            # when ``order`` is odd.
            self._pixel_center_finder = lambda i, j: (i[..., order**2 // 2], j[..., order**2 // 2])
        elif sampling_mode.startswith("upsample:"):
            upsample = int(sampling_mode.split(":")[1])
            # ``-1 % 2 == 1`` in Python, so positivity must be checked explicitly.
            if upsample < 1 or upsample % 2 != 1:
                raise SpecificationConflict(
                    f"Upsample factor for 'sampling_mode' must be an odd positive integer, got {upsample} for model {self.name}"
                )
            self._pixel_meshgridder = lambda im, w, p, u: im.pixel_center_meshgrid(
                w, upsample * p, upsample * u
            )
            self._pixel_integrator = lambda z: func.downsample_mean(z, upsample)
            self._pixel_center_finder = lambda i, j: (
                i[upsample // 2 :: upsample, upsample // 2 :: upsample],
                j[upsample // 2 :: upsample, upsample // 2 :: upsample],
            )
        else:
            raise SpecificationConflict(
                f"Unknown sampling mode {sampling_mode} for model {self.name}"
            )
        self._sampling_mode = sampling_mode

    @property
    def integrate_mode(self):
        """The adaptive integration mode: ``bright``, ``curvature`` or ``none``."""
        return self._integrate_mode

    @integrate_mode.setter
    def integrate_mode(self, integrate_mode):
        if integrate_mode == "bright":
            self._adaptive_integrator = self._bright_integrate
        elif integrate_mode == "curvature":
            self._adaptive_integrator = self._curvature_integrate
        elif integrate_mode == "none":
            # Identity: return the sampled image untouched.
            self._adaptive_integrator = lambda z, *a, **kw: z
        else:
            raise SpecificationConflict(
                f"Unknown integrate mode {integrate_mode} for model {self.name}"
            )
        self._integrate_mode = integrate_mode
class GradMixin:
    """
    Methods for computing the Jacobian of the model image, and the gradient of
    the likelihood, with respect to the model parameters.

    :param jacobian_maxparams: The maximum number of parameters before the
        Jacobian will be broken into smaller chunks to reduce memory
        consumption (int, default: 10).
    """

    # Maximum size of the parameter list before the Jacobian is broken into
    # smaller chunks. This helps limit the memory required to build the
    # Jacobian; a lower jacobian_maxparams is slower but uses less memory.
    jacobian_maxparams = 10

    _options = ("jacobian_maxparams",)

    def _jacobian(
        self,
        params_pre: ArrayLike,
        params: ArrayLike,
        params_post: ArrayLike,
    ) -> ArrayLike:
        """Jacobian of the model pixel data with respect to ``params`` only.

        The full parameter vector is rebuilt by concatenating
        ``(params_pre, params, params_post)`` along the last axis. Only the
        middle slice is differentiated; the other two are held fixed.
        """
        # A forward-mode jacobian (jacfwd) should in principle be more
        # efficient here, but its tracing overhead is too high.
        return backend.jacobian(
            lambda x: self(params=backend.concatenate((params_pre, x, params_post), dim=-1))._data,
            params,
        )

    def jacobian(
        self,
        pass_jacobian: Optional[JacobianImage] = None,
        params: Optional[ArrayLike] = None,
    ) -> JacobianImage:
        """Compute the Jacobian of the model image with respect to its parameters.

        :param pass_jacobian: An existing JacobianImage to add this model's
            contribution into. If None, a new one is built on
            ``self.target[self.window]`` for this model's parameters.
        :param params: Parameter values at which to evaluate the Jacobian.
            Defaults to the model's current values (``self.get_values()``).
        :returns: The JacobianImage. It is returned without changes when there
            are no parameters, or when ``jac_img.match_parameters`` finds none
            of this model's parameter identities.
        """
        if pass_jacobian is None:
            jac_img = self.target[self.window].jacobian_image(
                parameters=self.build_params_array_identities()
            )
        else:
            jac_img = pass_jacobian

        if params is None:
            params = self.get_values()
        # No dynamic params
        if len(params.shape) == 0 or params.shape[-1] == 0:
            return jac_img

        identities = self.build_params_array_identities()
        if len(jac_img.match_parameters(identities)[0]) == 0:
            return jac_img

        target = self.target[self.window]
        if len(params) > self.jacobian_maxparams:
            # Differentiate contiguous chunks of parameters one at a time to
            # bound peak memory; all other parameters are held fixed.
            # NOTE(review): chunksize grows with len(params), so a chunk can
            # exceed jacobian_maxparams once len(params) > jacobian_maxparams**2
            # -- confirm this is intended.
            chunksize = len(params) // self.jacobian_maxparams + 1
            for i in range(0, len(params), chunksize):
                params_pre = params[:i]
                params_chunk = params[i : i + chunksize]
                params_post = params[i + chunksize :]
                jac_chunk = self._jacobian(params_pre, params_chunk, params_post)
                jac_img += target.jacobian_image(
                    parameters=identities[i : i + chunksize],
                    data=jac_chunk,
                )
        else:
            jac = self._jacobian(params[:0], params, params[0:0])
            jac_img += target.jacobian_image(parameters=identities, data=jac)
        return jac_img

    def gradient(
        self,
        params: Optional[ArrayLike] = None,
        likelihood: Literal["gaussian", "poisson"] = "gaussian",
    ) -> ArrayLike:
        """Compute the gradient of the model likelihood with respect to its parameters.

        Masked pixels are excluded from the sum.

        :param params: Parameter values at which to evaluate the gradient.
            Both the Jacobian and the model image are evaluated at these
            values. Defaults to the model's current values.
        :param likelihood: ``"gaussian"`` sums ``J * (data - model) * weight``.
            ``"poisson"`` sums ``J * (1 - data / model)``.
        :returns: One gradient entry per parameter (the pixel axis is summed
            out).
        :raises ValueError: If ``likelihood`` is not ``"gaussian"`` or
            ``"poisson"``.
        """
        # Validate first: building the Jacobian is the expensive part.
        if likelihood not in ("gaussian", "poisson"):
            raise ValueError(
                f"Unknown likelihood {likelihood!r}, expected 'gaussian' or 'poisson'"
            )

        jacobian_image = self.jacobian(params=params).flatten("data")
        target = self.target[self.window]
        data = target.flatten("data")
        mask = target.flatten("mask")
        # Evaluate the model at the same parameters as the Jacobian; a bare
        # self() would always use the current values.
        model = (self() if params is None else self(params=params)).flatten("data")

        if likelihood == "gaussian":
            weight = target.flatten("weight")
            gradient = backend.sum(
                jacobian_image * ((data - model) * weight * (~mask))[..., None], dim=0
            )
        else:  # poisson
            # NOTE(review): (1 - data / model) is d/dm of the *negative* Poisson
            # log-likelihood (m - d*log m). The gaussian branch above,
            # (data - model) * weight, is d/dm of the *positive* log-likelihood.
            # The two branches therefore use opposite sign conventions --
            # confirm which one callers expect.
            gradient = backend.sum(
                jacobian_image * ((1 - data / model) * (~mask))[..., None],
                dim=0,
            )
        return gradient