Group Models

Here you will learn how to combine models into a larger, more complete model of a given system. This is a powerful and necessary capability when analysing objects in crowded environments. As telescopes achieve ever deeper photometry, we have learned that all environments are crowded when projected onto the sky!

import astrophot as ap
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
# first let's download an image to play with
############ UNCOMMENT IF RUNNING LOCALLY ############
# hdu = fits.open(
#     "https://www.legacysurvey.org/viewer/fits-cutout?ra=155.7720&dec=15.1494&size=150&layer=ls-dr9&pixscale=0.262&bands=r"
# )
# hdu.writeto("group_target_image.fits", overwrite=True)
hdu = fits.open("group_target_image.fits")
# hdu = ap.utils.ls_open(155.7720, 15.1494, 150 * 0.262, band="r")
target_data = np.array(hdu[0].data, dtype=np.float64)
fig1, ax1 = plt.subplots(figsize=(8, 8))
plt.imshow(np.arctan(target_data / 0.05), origin="lower", cmap="inferno")
plt.axis("off")
plt.show()
#########################################
# NOTE: photutils is not a dependency of AstroPhot, make sure you run: pip install photutils
# if you don't already have that package. Also note that you can use any segmentation map
# code; we just use photutils here because it is very easy.
#########################################
from photutils.segmentation import detect_sources, deblend_sources

initsegmap = detect_sources(target_data, threshold=0.02, npixels=6)
segmap = deblend_sources(target_data, initsegmap, npixels=5).data
fig8, ax8 = plt.subplots(figsize=(8, 8))
ax8.imshow(segmap, origin="lower", cmap="inferno")
plt.show()
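photutils' detect_sources does considerably more than this, but its core idea (keep connected groups of pixels above a threshold, and discard groups smaller than npixels) can be sketched in plain NumPy. toy_detect_sources is a hypothetical helper written for illustration, not part of photutils or AstroPhot; it assumes 8-connectivity and does no deblending, which is what deblend_sources adds on top.

```python
import numpy as np
from collections import deque


def toy_detect_sources(data, threshold, npixels):
    """Label 8-connected groups of pixels above `threshold`, keeping only groups
    with at least `npixels` pixels. Returns an int array where 0 is background
    and 1..N label the kept sources. (Hypothetical simplification.)"""
    above = data > threshold
    labels = np.zeros(data.shape, dtype=int)
    visited = np.zeros(data.shape, dtype=bool)
    ny, nx = data.shape
    next_label = 1
    for y0 in range(ny):
        for x0 in range(nx):
            if not above[y0, x0] or visited[y0, x0]:
                continue
            # breadth-first flood fill of one connected group of bright pixels
            group = []
            queue = deque([(y0, x0)])
            visited[y0, x0] = True
            while queue:
                y, x = queue.popleft()
                group.append((y, x))
                for dy in (-1, 0, 1):
                    for dx in (-1, 0, 1):
                        yy, xx = y + dy, x + dx
                        if 0 <= yy < ny and 0 <= xx < nx and above[yy, xx] and not visited[yy, xx]:
                            visited[yy, xx] = True
                            queue.append((yy, xx))
            # groups that are too small (e.g. hot pixels) are not counted as sources
            if len(group) >= npixels:
                for y, x in group:
                    labels[y, x] = next_label
                next_label += 1
    return labels


img = np.zeros((10, 10))
img[1:4, 1:4] = 1.0  # a 9-pixel blob
img[7, 7] = 1.0      # a single hot pixel
seg = toy_detect_sources(img, threshold=0.5, npixels=6)
print(np.unique(seg))  # [0 1]: the blob is labelled, the hot pixel is not
```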
pixelscale = 0.262
target = ap.TargetImage(
    data=target_data + 0.01,  # add fake sky level back in
    pixelscale=pixelscale,
    zeropoint=22.5,
    variance="auto",  # this will estimate the variance from the data
)
fig2, ax2 = plt.subplots(figsize=(8, 8))
ap.plots.target_image(fig2, ax2, target)
plt.show()
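The zeropoint=22.5 argument sets the magnitude scale: a total flux F in image units corresponds to a magnitude m = 22.5 - 2.5 log10(F). Legacy Survey images are in nanomaggies, for which 22.5 is the AB zeropoint. The two helpers below are hypothetical names for this standard conversion, not AstroPhot functions.

```python
import numpy as np


def flux_to_mag(flux, zeropoint=22.5):
    """Magnitude from total flux in image units: m = zp - 2.5 log10(F)."""
    return zeropoint - 2.5 * np.log10(flux)


def mag_to_flux(mag, zeropoint=22.5):
    """Inverse of flux_to_mag: F = 10**(-0.4 (m - zp))."""
    return 10 ** (-0.4 * (mag - zeropoint))


print(round(float(flux_to_mag(20.11)), 2))  # about 19.24
```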

Group Model

A group model takes a list of other AstroPhot models and tracks them so that they can be treated as a single larger model. When “initialize” is called on the group model, it simply calls “initialize” on each of the individual models, and the same is true for a number of other functions. For fitting, however, the group model collects the parameters from all the models and passes them to the optimizer as one group. When saving a group model, the states of all the models are collected into one large file.

The main difference when constructing a group model is that you must first create all the sub models that will go in it. Once constructed, a group model behaves just like any other model; in fact, they are all built from the same base class.
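The bookkeeping described above (forward "initialize" to every member, expose one combined parameter vector to the optimizer, and split it back up when evaluating) can be illustrated with a toy composite in NumPy. ToyModel and ToyGroup are hypothetical classes written for this sketch, not AstroPhot's implementation; they assume each model holds a flat parameter vector and that the group image is the sum of its members' images.

```python
import numpy as np


class ToyModel:
    """Hypothetical single model: a name, a flat parameter vector, and an image function."""

    def __init__(self, name, params, image_fn):
        self.name = name
        self.params = np.asarray(params, dtype=float)
        self.image_fn = image_fn  # maps a parameter vector to a model image

    def initialize(self):
        print(f"Initializing model {self.name}")

    def n_params(self):
        return self.params.size

    def __call__(self, params):
        return self.image_fn(params)


class ToyGroup:
    """Hypothetical group: forwards calls to members and concatenates their parameters."""

    def __init__(self, models):
        self.models = list(models)

    def initialize(self):
        for model in self.models:
            model.initialize()

    def get_values(self):
        # one flat vector: this is all the optimizer ever sees
        return np.concatenate([m.params for m in self.models])

    def __call__(self, params):
        # split the flat vector back into per-model slices and sum the images
        image = 0.0
        start = 0
        for m in self.models:
            stop = start + m.n_params()
            image = image + m(params[start:stop])
            start = stop
        return image


shape = (4, 4)
yy, xx = np.mgrid[0:4, 0:4]
sky = ToyModel("sky", [0.1], lambda p: np.full(shape, p[0]))
blob = ToyModel("blob", [2.0, 1.5, 1.5],
                lambda p: p[0] * np.exp(-((xx - p[1]) ** 2 + (yy - p[2]) ** 2)))
group = ToyGroup([sky, blob])
group.initialize()
x = group.get_values()  # [0.1, 2.0, 1.5, 1.5]
img = group(x)
```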

# This will convert the segmentation map into boxes that enclose the identified pixels
windows = ap.utils.initialize.windows_from_segmentation_map(segmap)
# Next we scale up the windows so that AstroPhot can fit the faint parts of each object as well
windows = ap.utils.initialize.scale_windows(windows, image=target, expand_scale=2, expand_border=10)
# Here we get some basic starting parameters for the galaxies (center, position angle, axis ratio)
centers = ap.utils.initialize.centroids_from_segmentation_map(segmap, target)
PAs = ap.utils.initialize.PA_from_segmentation_map(segmap, target, centers)
qs = ap.utils.initialize.q_from_segmentation_map(segmap, target, centers)
# Now we use all the windows to add to the list of models
seg_models = []
for win in windows:
    seg_models.append(
        ap.Model(
            name=f"object_{win:02d}",
            window=windows[win],
            model_type="sersic galaxy model",
            target=target,
            center=centers[win],
            PA=PAs[win],
            q=qs[win],
        )
    )
sky = ap.Model(
    name="sky_level",
    model_type="flat sky model",
    target=target,
    I0={"valid": (0, None)},
)

# We build the group model just like any other, except we pass a list of other models
groupmodel = ap.Model(
    name="group", models=[sky] + seg_models, target=target, model_type="group model"
)

groupmodel.initialize()
Initializing model sky_level
Initializing model object_01
Initializing model object_02
Initializing model object_03
Initializing model object_04
Initializing model object_05
Initializing model object_06
Initializing model object_07
Initializing model object_08
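The windows used above come from the segmentation map: a bounding box around each labelled region, then enlarged so the faint outskirts of each object are included. The sketch below shows one way to do this. toy_windows_from_segmap and toy_scale_window are hypothetical helpers; they use their own [[xmin, xmax], [ymin, ymax]] format, and the growth rule (scale the half-size by expand_scale, pad by expand_border pixels, clip to the image) is an assumption that may not match what AstroPhot's scale_windows does.

```python
import numpy as np


def toy_windows_from_segmap(segmap):
    """Return {label: [[xmin, xmax], [ymin, ymax]]} bounding boxes (inclusive
    pixel indices) for every nonzero label in an integer segmentation map."""
    windows = {}
    for label in np.unique(segmap):
        if label == 0:  # 0 is background
            continue
        ys, xs = np.nonzero(segmap == label)
        windows[int(label)] = [[int(xs.min()), int(xs.max())], [int(ys.min()), int(ys.max())]]
    return windows


def toy_scale_window(window, shape, expand_scale=2.0, expand_border=10):
    """Grow a box about its center by `expand_scale`, pad it by `expand_border`
    pixels, and clip it to an image of shape (ny, nx). Assumed rule, see above."""
    (x0, x1), (y0, y1) = window
    ny, nx = shape
    out = []
    for lo, hi, n in ((x0, x1, nx), (y0, y1, ny)):
        center = (lo + hi) / 2
        half = (hi - lo) / 2 * expand_scale + expand_border
        out.append([max(0, int(np.floor(center - half))), min(n - 1, int(np.ceil(center + half)))])
    return out


seg = np.zeros((50, 50), dtype=int)
seg[20:25, 10:14] = 1  # rows 20-24, columns 10-13
wins = toy_windows_from_segmap(seg)
print(wins)  # {1: [[10, 13], [20, 24]]}
print(toy_scale_window(wins[1], seg.shape, expand_scale=2, expand_border=10))  # [[0, 25], [8, 36]]
```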
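import torch

# collect the current parameter values of all the sub models into a single tensor
x = groupmodel.get_values()
# repeat it to make a batch of 5 identical parameter vectors
x = x.repeat(5, 1)
# torch.vmap evaluates the group model once per parameter vector and stacks the 5 images
imgs = torch.vmap(lambda params: groupmodel(params).data)(x)
print(imgs.shape)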
torch.Size([5, 150, 150])
fig, ax = plt.subplots(1, 2, figsize=(18, 8))
ap.plots.target_image(fig, ax[0], groupmodel.target)
ap.plots.model_window(fig, ax[0], groupmodel)
ax[0].set_title("Sub model fitting windows")
ap.plots.model_image(fig, ax[1], groupmodel)
ax[1].set_title("auto initialized parameters")
plt.show()
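The auto-initialized centers, position angles and axis ratios come from centroids_from_segmentation_map, PA_from_segmentation_map and q_from_segmentation_map. We have not checked exactly how AstroPhot computes them; a common approach, sketched below, uses flux-weighted first and second image moments over each segment's pixels. toy_moments is a hypothetical helper, and its angle convention (counterclockwise from the +x axis, in radians) may differ from AstroPhot's.

```python
import numpy as np


def toy_moments(image, mask):
    """Flux-weighted centroid, position angle and axis ratio of the pixels in `mask`,
    from first and second image moments."""
    ys, xs = np.nonzero(mask)
    w = np.clip(image[ys, xs], 0, None)  # ignore negative (noise) pixels
    w = w / w.sum()
    xc, yc = np.sum(w * xs), np.sum(w * ys)
    dx, dy = xs - xc, ys - yc
    mxx, myy, mxy = np.sum(w * dx * dx), np.sum(w * dy * dy), np.sum(w * dx * dy)
    # eigenvalues of the second-moment matrix are proportional to the squared semi-axes
    common = np.sqrt(((mxx - myy) / 2) ** 2 + mxy ** 2)
    lam_major = (mxx + myy) / 2 + common
    lam_minor = (mxx + myy) / 2 - common
    pa = 0.5 * np.arctan2(2 * mxy, mxx - myy)  # direction of the major axis
    q = np.sqrt(lam_minor / lam_major)         # minor / major axis ratio
    return (xc, yc), pa, q


# an elliptical Gaussian centred at (30, 30), major axis at 30 degrees, q = 3/6
yy, xx = np.mgrid[0:61, 0:61]
t = np.deg2rad(30.0)
u = (xx - 30) * np.cos(t) + (yy - 30) * np.sin(t)   # coordinate along the major axis
v = -(xx - 30) * np.sin(t) + (yy - 30) * np.cos(t)  # coordinate along the minor axis
img = np.exp(-0.5 * ((u / 6.0) ** 2 + (v / 3.0) ** 2))
center, pa, q = toy_moments(img, img > 1e-3)
print(center, np.rad2deg(pa), q)  # roughly (30, 30), 30 degrees, 0.5
```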
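# This is now a very complex model composed of 9 sub-models! In total 57 parameters!
# Iter cycles through the sub models one at a time, then LM refines all the parameters together.
# Here we limit each fitter to 2 iterations so that it runs quickly. In general you should let it run to convergence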
result = ap.fit.Iter(groupmodel, verbose=1, max_iter=2).fit()
result = ap.fit.LM(groupmodel, verbose=0, max_iter=2).fit()
--------iter-------
sky_level
object_01
object_02
object_03
object_04
object_05
object_06
object_07
object_08
Update Chi^2 with new parameters
Loss: 1.6905223992071536
--------iter-------
sky_level
object_01
object_02
object_03
object_04
object_05
object_06
object_07
object_08
Update Chi^2 with new parameters
Loss: 1.6798977180632828
# Now we can see what the fitting has produced
fig10, ax10 = plt.subplots(1, 2, figsize=(16, 7))
ap.plots.model_image(fig10, ax10[0], groupmodel, vmax=25)
ap.plots.residual_image(fig10, ax10[1], groupmodel, normalize_residuals=True)
plt.show()
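The right panel uses normalize_residuals=True, which we take to mean residuals divided by the per-pixel uncertainty (an assumption about that plotting option). Squaring and summing those gives a chi-squared, and dividing by the degrees of freedom (pixels minus fitted parameters, here 57) gives a reduced chi-squared, where values near 1 mean the residuals are consistent with the noise. The Loss values printed by the fitters may be a quantity of this kind, but we have not confirmed AstroPhot's exact normalization. The helpers below are hypothetical and the data are synthetic.

```python
import numpy as np


def normalized_residuals(data, model, variance):
    """Residuals in units of the per-pixel noise: (data - model) / sigma."""
    return (data - model) / np.sqrt(variance)


def reduced_chi2(data, model, variance, n_params):
    """Sum of squared normalized residuals divided by (Npix - Nparams)."""
    r = normalized_residuals(data, model, variance)
    return np.sum(r ** 2) / (r.size - n_params)


# synthetic check: a perfect model plus pure noise should give a value close to 1
rng = np.random.default_rng(0)
model = np.ones((150, 150))
variance = np.full(model.shape, 0.01)
data = model + rng.normal(0.0, 0.1, size=model.shape)
chi2 = reduced_chi2(data, model, variance, n_params=57)
print(chi2)
```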

This is a pretty good fit! We haven’t accounted for the PSF yet, so some of the central regions are not fit very well. Adding a PSF model to AstroPhot for fitting is easy; check out the Basic PSF Models tutorial for more information.

Segmentation maps

AstroPhot can produce a model-based segmentation map. Once the models are fit, it can compute the “importance” of each pixel to a given model: for each pixel and each model, it computes what fraction of that model’s total flux is placed in that pixel. Whichever model assigns the highest fraction of its flux to a given pixel is the “winner” for that pixel, and the segmentation map assigns the pixel to that model’s index. Note that this is only done at the first level of a group model. Since group models can contain other group models, a complex multi-component model can still act as a single index in the segmentation map.

Also note that this means AstroPhot can perform segmentation even for images with non-zero sky levels; there is no need to do background subtraction before segmenting (though you do need to fit the models first).
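The winner-takes-all rule can be written in a few lines of NumPy, assuming you already have one full-size image per first-level sub model. toy_segmentation_map is a hypothetical helper, not AstroPhot's segmentation_map, and it ignores details such as sub-model windows. Because each model is normalized by its own total flux, a bright but flat sky spreads its fraction thinly over every pixel and does not swamp compact sources, which is why no background subtraction is needed.

```python
import numpy as np


def toy_segmentation_map(model_images):
    """Assign each pixel to the model that puts the largest fraction of its own
    total flux there. `model_images` has shape (n_models, ny, nx) and every model
    is assumed to have a positive total flux. Returns an (ny, nx) array of indices."""
    model_images = np.asarray(model_images, dtype=float)
    totals = model_images.sum(axis=(1, 2), keepdims=True)
    fractions = model_images / totals  # each model's flux fraction in each pixel
    return np.argmax(fractions, axis=0)


sky = np.full((5, 5), 100.0)  # bright flat sky: 1/25 of its flux in every pixel
src = np.zeros((5, 5))        # faint compact source
src[2, 2] = 1.0
src[1, 2] = src[3, 2] = src[2, 1] = src[2, 3] = 0.5
seg = toy_segmentation_map([sky, src])
print(seg.sum())  # 5: the source wins its 5 pixels despite the much brighter sky
```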

plt.imshow(groupmodel.segmentation_map(), origin="lower", cmap="inferno")
plt.show()

Deblending

AstroPhot can perform a basic deblending based on the fitted model. For each object, a new target image is created in which each pixel holds the original target’s signal multiplied by the fraction of the full group model’s light in that pixel that comes from that object’s model. This can create patches of zero (or near-zero) pixel values where the model falls to zero within its own window, or where other models are much brighter.

Note that this works even when the sky level has not been subtracted, although for very bright sky levels the deblended objects tend to look just like their model images.
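The splitting rule can be sketched in NumPy: each stamp is the target multiplied by that model's share of the summed model light, pixel by pixel. toy_deblend is a hypothetical helper, not AstroPhot's deblend method; it assumes full-size model images and gives a pixel to no model where the summed model is zero.

```python
import numpy as np


def toy_deblend(target, model_images):
    """Split `target` into one stamp per model: each pixel gets the target value
    times that model's share of the summed model light in that pixel."""
    model_images = np.asarray(model_images, dtype=float)
    total = model_images.sum(axis=0)
    # where the summed model is zero, no model claims the pixel (share stays 0)
    shares = np.divide(model_images, total, out=np.zeros_like(model_images), where=total > 0)
    return shares * target


target = np.array([[4.0, 6.0], [2.0, 8.0]])
m1 = np.array([[1.0, 1.0], [0.0, 3.0]])
m2 = np.array([[1.0, 2.0], [0.0, 1.0]])
stamps = toy_deblend(target, [m1, m2])
print(stamps[0])  # [[2. 2.] [0. 6.]]
print(stamps[1])  # [[2. 4.] [0. 2.]]
```

Where at least one model is non-zero, the stamps sum back to the target, so the deblended total fluxes account for all of the observed signal in those pixels.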

AstroPhot itself doesn’t use deblending: its forward-modelling approach simultaneously models all objects using a principled Gaussian (or Poisson) likelihood. That said, other analyses may make use of deblended stamps, and they also provide a good systematic check of the flux estimates. A model flux that differs wildly from the deblended total flux might be cause for concern.

subtargets = groupmodel.deblend()
fig, axarr = plt.subplots(2, int(np.ceil(len(subtargets) / 2)), figsize=(16, 7))
for i, subtarget in enumerate(subtargets):
    ax = axarr.flatten()[i]
    ap.plots.target_image(fig, ax, subtarget)
    ax.set_title(subtarget.name, fontsize=10)
    ax.axis("off")
axarr.flatten()[-1].axis("off")
plt.show()

for submodel, subtarget in zip(groupmodel.models, subtargets):
    print(
        f"{submodel.name}: total model flux = {submodel.total_flux().item():.2f} ± {submodel.total_flux_uncertainty().item():.2f}, deblend total flux = {subtarget.data.sum().item():.2f}"
    )
sky_level: total model flux = 228.04 ± 0.65, deblend total flux = 239.35
object_01: total model flux = 20.11 ± 0.43, deblend total flux = 20.30
object_02: total model flux = 7.31 ± 0.23, deblend total flux = 7.37
object_03: total model flux = 23.78 ± 1.01, deblend total flux = 24.20
object_04: total model flux = 26.75 ± 0.56, deblend total flux = 27.01
object_05: total model flux = 9.34 ± 0.35, deblend total flux = 9.44
object_06: total model flux = 56.10 ± 1.16, deblend total flux = 56.90
object_07: total model flux = 2.58 ± 0.27, deblend total flux = 2.62
object_08: total model flux = 2.19 ± 0.92, deblend total flux = 2.26

Observe that for every model except the sky (whose level we fudged anyway), the model flux and the deblended flux agree to within one sigma. This is a good sign! Any major deviations would have been very suspicious.
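The one-sigma comparison can be made explicit by computing a pull for each sub model: the deblended flux minus the model flux, divided by the model flux uncertainty. The numbers below are copied from the printed output above; in a live session you could build the same table from submodel.total_flux(), submodel.total_flux_uncertainty() and subtarget.data.sum(), as in the loop above.

```python
# (model flux, model flux uncertainty, deblended flux), copied from the printed output above
fluxes = {
    "sky_level": (228.04, 0.65, 239.35),
    "object_01": (20.11, 0.43, 20.30),
    "object_02": (7.31, 0.23, 7.37),
    "object_03": (23.78, 1.01, 24.20),
    "object_04": (26.75, 0.56, 27.01),
    "object_05": (9.34, 0.35, 9.44),
    "object_06": (56.10, 1.16, 56.90),
    "object_07": (2.58, 0.27, 2.62),
    "object_08": (2.19, 0.92, 2.26),
}

# pull = how many model-flux sigmas separate the deblended flux from the model flux
pulls = {name: (deblend - model) / sigma for name, (model, sigma, deblend) in fluxes.items()}
for name, pull in pulls.items():
    flag = "" if abs(pull) < 1 else "  <-- outside one sigma"
    print(f"{name}: {pull:+.2f} sigma{flag}")
```

Only sky_level is flagged, matching the statement above.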