# Source code for astrophot.models.func.spline
from ...backend_obj import backend, ArrayLike
from ... import config
def _h_poly(t: ArrayLike) -> ArrayLike:
    """Evaluate the four cubic Hermite basis polynomials at ``t``.

    The rows of the result are, in order, the weights applied to the left
    value, the left slope, the right value and the right slope of a segment:

    - row 0: ``1 - 3 t^2 + 2 t^3``
    - row 1: ``t - 2 t^2 + t^3``
    - row 2: ``3 t^2 - 2 t^3``
    - row 3: ``-t^2 + t^3``

    Args:
        t (Tensor): A 1D tensor of normalized positions within a segment.

    Returns:
        Tensor: A 2D tensor of size (4, len(t)) holding the basis values.
    """
    # Row k of `powers` is t**k for k = 0..3, giving shape (4, len(t)).
    exponents = backend.arange(4, device=config.DEVICE)
    powers = t[None, :] ** (exponents[:, None])

    # Each row lists one basis polynomial's coefficients in ascending powers of t.
    coefficients = backend.as_array(
        [
            [1.0, 0.0, -3.0, 2.0],
            [0.0, 1.0, -2.0, 1.0],
            [0.0, 0.0, 3.0, -2.0],
            [0.0, 0.0, -1.0, 1.0],
        ],
        dtype=config.DTYPE,
        device=config.DEVICE,
    )
    return coefficients @ powers
# [docs]  (Sphinx viewcode link marker; not part of the module code)
def cubic_spline_torch(x: ArrayLike, y: ArrayLike, xs: ArrayLike) -> ArrayLike:
    """Evaluate a 1D cubic Hermite spline through the given data points
    using the configured array backend.

    Knot slopes are estimated from the data: each interior knot uses the mean
    of the secant slopes of its two neighbouring segments, and each end knot
    uses the secant slope of its single adjacent segment.

    Query points outside the knot range are evaluated with the cubic of the
    nearest end segment (extrapolation); no masking is applied here.

    **Args:**
    -  `x` (Tensor): A 1D tensor representing the x-coordinates of the known data points.
       NOTE(review): presumably sorted in increasing order, since it is passed to
       `searchsorted` — confirm against callers.
    -  `y` (Tensor): A 1D tensor representing the y-coordinates of the known data points.
    -  `xs` (Tensor): A 1D tensor representing the x-coordinates of the positions where
       the cubic spline function should be evaluated.

    **Returns:**
    -  Tensor: The spline evaluated at each entry of `xs`.
    """
    # Secant slope of every segment between consecutive knots.
    secant = (y[1:] - y[:-1]) / (x[1:] - x[:-1])
    # Slope at every knot: end knots reuse their only secant, interior knots average two.
    m = backend.concatenate(
        [secant[0].flatten(), (secant[1:] + secant[:-1]) / 2, secant[-1].flatten()]
    )

    # Index of the left knot of the segment containing each query point. Searching
    # only x[:-1] caps the result at the last segment for large xs.
    idxs = backend.searchsorted(x[:-1], xs) - 1
    # Any query below x[0] (and, with a left-sided search, equal to x[0]) yields -1
    # here. A negative index would wrap around to the last knot and pair it with the
    # first one, so clamp to the first segment instead.
    idxs = backend.where(idxs < 0, backend.zeros_like(idxs), idxs)

    dx = x[idxs + 1] - x[idxs]
    # Hermite basis evaluated at the position normalized to the segment width.
    hh = _h_poly((xs - x[idxs]) / dx)
    # Slopes are scaled by dx because the basis is defined on the normalized coordinate.
    ret = hh[0] * y[idxs] + hh[1] * m[idxs] * dx + hh[2] * y[idxs + 1] + hh[3] * m[idxs + 1] * dx
    return ret
# [docs]  (Sphinx viewcode link marker; not part of the module code)
def spline(R: ArrayLike, profR: ArrayLike, profI: ArrayLike, extend: str = "zeros") -> ArrayLike:
    """Evaluate a 1D spline profile at the radii `R`.

    The profile nodes `(profR, profI)` are interpolated with
    `cubic_spline_torch`; every position with `R > profR[-1]` is then replaced
    according to `extend`.

    **Args:**
    -  `R`: Radii tensor at which to evaluate the spline function
    -  `profR`: radius values for the surface density profile in the same units as `R`
    -  `profI`: surface density values for the surface density profile
    -  `extend`: How to fill positions beyond the last point. `'zeros'` uses 0,
       `'const'` uses the last profile value `profI[-1]`.

    **Returns:**
    -  Tensor with the same shape as `R` holding the profile values.

    **Raises:**
    -  `ValueError`: if `extend` is neither `'zeros'` nor `'const'`.
    """
    # The interpolator works on 1D input, so flatten and restore the shape afterwards.
    interpolated = cubic_spline_torch(profR, profI, R.flatten())
    values = interpolated.reshape(*R.shape)

    if extend not in ("zeros", "const"):
        raise ValueError(f"Unknown extend option: {extend}. Use 'zeros' or 'const'.")

    beyond_profile = R > profR[-1]
    if extend == "zeros":
        outer_value = backend.zeros_like(values)
    else:
        outer_value = profI[-1] * backend.ones_like(values)
    return backend.where(beyond_profile, outer_value, values)