from typing import Optional, Tuple
import torch
import numpy as np
from .model_object import ComponentModel
from ..utils.decorators import ignore_numpy_warnings, combine_docstrings
from . import func
from .. import config
from ..backend_obj import backend, ArrayLike
from ..param import forward
__all__ = ["MultiGaussianExpansion"]
[docs]
@combine_docstrings
class MultiGaussianExpansion(ComponentModel):
    """Model that represents a galaxy as a sum of multiple Gaussian
    profiles. The model is defined as:

    .. math::

        I(R) = \\sum_i {\\rm flux}_i * \\exp(-0.5*(R_i / \\sigma_i)^2) / (2 * \\pi * q_i * \\sigma_i^2)

    where :math:`R_i` is a radius computed using :math:`q_i` and :math:`PA_i` for that component. All components share the same center.

    :param q: axis ratio to scale minor axis from the ratio of the minor/major axis b/a, this parameter is unitless, it is restricted to the range (0,1)
    :param PA: position angle of the semi-major axis East of North, it is a cyclic parameter in the range [0,pi)
    :param sigma: standard deviation of each Gaussian
    :param flux: amplitude of each Gaussian
    """

    _model_type = "mge"
    _parameter_specs = {
        "q": {
            "units": "b/a",
            "valid": (0, 1),
            "shape": (None,),
            "dynamic": True,
            "description": "axis ratio to scale minor axis from the ratio of the minor/major axis b/a",
        },
        "PA": {
            "units": "radians",
            "valid": (0, np.pi),
            "cyclic": True,
            "dynamic": True,
            "description": "position angle of the semi-major axis East of North, it is a cyclic parameter in the range [0,pi)",
        },  # No shape for PA since there are two options, use with caution
        "sigma": {
            "units": "arcsec",
            "valid": (0, None),
            "shape": (None,),
            "dynamic": True,
            "description": "standard deviation of each Gaussian",
        },
        "flux": {
            "units": "flux",
            "shape": (None,),
            "dynamic": True,
            "description": "amplitude of each Gaussian",
        },
    }
    usable = True

    def __init__(self, *args, n_components: Optional[int] = None, **kwargs):
        """Build the model and determine how many Gaussian components it has.

        :param n_components: number of Gaussians. When ``None`` the count is
            taken from the leading dimension of the first of ``q``, ``sigma``
            or ``flux`` that already has a value; if none of them has a value
            a single component is used.

        All other positional and keyword arguments are forwarded unchanged to
        the parent constructor.
        """
        super().__init__(*args, **kwargs)
        if n_components is None:
            for key in ("q", "sigma", "flux"):
                if self[key].value is not None:
                    self.n_components = self[key].value.shape[0]
                    break
            else:
                # No per-component parameter was given a value.
                self.n_components = 1
        else:
            self.n_components = int(n_components)

    @torch.no_grad()
    @ignore_numpy_warnings
    def initialize(self):
        """Fill in starting values for any parameter that is not yet initialized.

        Works on the target image inside the model window, with masked pixels
        replaced by the median of the unmasked ones and the median of the
        border pixels subtracted as a background estimate.

        - ``sigma``: log-spaced from 3 pixel scales up to 0.7 times the longest
          image side (in the units of the pixel scale).
        - ``flux``: the summed image flux split evenly between components.
        - ``PA`` and ``q``: estimated from median-based image moments about the
          model center and applied identically to every component, with fixed
          fallbacks when the moments are unusable.
        """
        super().initialize()

        target_area = self.target[self.window]
        dat = backend.to_numpy(target_area._data).copy()
        mask = backend.to_numpy(target_area._mask)
        dat[mask] = np.median(dat[~mask])
        edge = np.concatenate((dat[:, 0], dat[:, -1], dat[0, :], dat[-1, :]))
        edge_average = np.nanmedian(edge)
        dat -= edge_average

        if not self.sigma.initialized:
            pixelscale = target_area.pixelscale.item()
            # np.logspace takes base-10 *exponents* for both endpoints, so the
            # upper bound must be log10'd just like the lower bound.
            self.sigma.value = np.logspace(
                np.log10(pixelscale * 3),
                np.log10(max(target_area.data.shape) * pixelscale * 0.7),
                self.n_components,
            )
        if not self.flux.initialized:
            self.flux.value = (np.sum(dat) / self.n_components) * np.ones(self.n_components)

        if self.PA.initialized and self.q.initialized:
            return

        x, y = target_area.coordinate_center_meshgrid()
        x = backend.to_numpy(x - self.center.value[0])
        y = backend.to_numpy(y - self.center.value[1])
        # Median-based, first-order stand-ins for the classical second moments
        # (dat * x**2, dat * y**2, dat * x * y); the cross term is softened so
        # that it stays finite where x * y == 0.
        mu20 = np.median(dat * np.abs(x))
        mu02 = np.median(dat * np.abs(y))
        mu11 = np.median(dat * x * y / np.sqrt(np.abs(x * y) + self.softening**2))
        M = np.array([[mu20, mu11], [mu11, mu02]])
        bad_moments = np.any(np.iscomplex(M)) or np.any(~np.isfinite(M))
        ones = np.ones(self.n_components)

        if not self.PA.initialized:
            if bad_moments:
                self.PA.value = ones * np.pi / 2
            else:
                self.PA.value = ones * (0.5 * np.arctan2(2 * mu11, mu20 - mu02) - np.pi / 2) % np.pi

        if not self.q.initialized:
            fallback = (0.7, 1.0)
            if bad_moments:
                # np.linalg.eigvals raises LinAlgError on NaN/inf input, so the
                # fallback has to be chosen before calling it.
                eigenvalues = fallback
            else:
                eigenvalues = np.sort(np.linalg.eigvals(M))
                if (
                    np.any(np.iscomplex(eigenvalues))
                    or np.any(~np.isfinite(eigenvalues))
                    or eigenvalues[0] < 0
                    or eigenvalues[1] <= 0
                ):
                    # A negative or zero eigenvalue would make the square root
                    # below NaN/inf, which np.clip does not repair.
                    eigenvalues = fallback
            self.q.value = ones * np.clip(np.sqrt(eigenvalues[0] / eigenvalues[1]), 0.1, 0.9)

    @forward
    def brightness(
        self, x: ArrayLike, y: ArrayLike, flux: ArrayLike, sigma: ArrayLike, q: ArrayLike
    ) -> ArrayLike:
        """Evaluate the summed brightness of all Gaussian components.

        The coordinates are passed through ``self.transform_coordinates`` and
        ``self.radius_metric``; each component's normalized Gaussian
        ``flux / (2 * pi * q * sigma**2) * exp(-0.5 * (R / sigma)**2)`` is
        evaluated with ``backend.vmap`` and the results are summed over
        dimension 0.

        NOTE(review): ``R`` is mapped alongside ``flux``, ``sigma`` and ``q``,
        so ``radius_metric`` presumably returns one radius map per component
        along the leading axis -- confirm against its definition.

        :param x: x coordinates at which to evaluate the model
        :param y: y coordinates at which to evaluate the model
        :param flux: amplitude of each Gaussian
        :param sigma: standard deviation of each Gaussian
        :param q: axis ratio of each Gaussian
        :return: brightness summed over components
        """
        x, y = self.transform_coordinates(x, y)
        R = self.radius_metric(x, y)

        def gaussian(A, r, sig, _q):
            # Single normalized elliptical Gaussian.
            return (A / (2 * np.pi * _q * sig**2)) * backend.exp(-0.5 * (r / sig) ** 2)

        return backend.sum(backend.vmap(gaussian)(flux, R, sigma, q), dim=0)