Source code for astrophot.backend_obj

import os
import importlib
from typing import Annotated

from torch import Tensor, dtype, device
import torch
import numpy as np
import caskade as ck

from . import config

# Annotation aliases for type hints. Each alias is nominally the torch type
# (the first Annotated argument); the string that follows is descriptive
# metadata only and is not enforced at runtime.

# Array values handled by the backend.
ArrayLike = Annotated[
    Tensor,
    "One of: torch.Tensor or jax.numpy.ndarray depending on the chosen backend.",
]
# Element type of an array.
dtypeLike = Annotated[
    dtype,
    "One of: torch.dtype or jax.numpy.dtype depending on the chosen backend.",
]
# Placement (device) of an array.
deviceLike = Annotated[
    device,
    "One of: torch.device or jax.DeviceArray depending on the chosen backend.",
]


class Backend:
    """Dispatch array operations to either ``torch`` or ``jax.numpy``.

    The backend name is taken from the constructor argument or, when that is
    ``None``, from the ``CASKADE_BACKEND`` environment variable (default
    ``"torch"``). Assigning to :attr:`backend` also sets
    ``ck.backend.backend``, imports the array module into ``self.module`` and
    rebinds the backend-specific operations (``make_array``, ``to``,
    ``transpose``, ...) to the matching ``_*_torch`` / ``_*_jax`` helper.

    Methods defined directly on the class (``arange``, ``zeros``, ``log``,
    ...) forward to a function of the same name on ``self.module``.
    """

    def __init__(self, backend=None):
        """Select and load ``backend`` (``"torch"``, ``"jax"`` or ``None``)."""
        self.backend = backend

    @property
    def backend(self):
        """Name of the currently loaded backend."""
        return self._backend

    @backend.setter
    def backend(self, backend):
        if backend is None:
            backend = os.getenv("CASKADE_BACKEND", "torch")
        # caskade is switched first, then the local bindings are rebuilt.
        ck.backend.backend = backend
        self._load_backend(backend)
        self._backend = backend

    def _load_backend(self, backend):
        """Import the array module for ``backend`` and bind its operations.

        Raises:
            ValueError: if ``backend`` is neither ``"torch"`` nor ``"jax"``.
        """
        if backend == "torch":
            self.module = importlib.import_module("torch")
            self.setup_torch()
        elif backend == "jax":
            self.module = importlib.import_module("jax.numpy")
            self.setup_jax()
        else:
            raise ValueError(f"Unsupported backend: {backend}")

    def setup_torch(self):
        """Set the config defaults and bind every operation to its torch helper."""
        config.DTYPE = torch.float64
        config.DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.make_array = self._make_array_torch
        self._array_type = self._array_type_torch
        self.concatenate = self._concatenate_torch
        self.copy = self._copy_torch
        self.tolist = self._tolist_torch
        self.view = self._view_torch
        self.as_array = self._as_array_torch
        self.to = self._to_torch
        self.to_numpy = self._to_numpy_torch
        self.gammaln = self._gammaln_torch
        self.logit = self._logit_torch
        self.sigmoid = self._sigmoid_torch
        self.repeat = self._repeat_torch
        self.stack = self._stack_torch
        self.transpose = self._transpose_torch
        self.upsample2d = self._upsample2d_torch
        self.pad = self._pad_torch
        # Public alias of torch._C._LinAlgError (the same class object).
        self.LinAlgErr = self.module.linalg.LinAlgError
        self.roll = self._roll_torch
        self.clamp = self._clamp_torch
        self.flatten = self._flatten_torch
        self.conv2d = self._conv2d_torch
        self.mean = self._mean_torch
        self.sum = self._sum_torch
        self.max = self._max_torch
        self.topk = self._topk_torch
        self.bessel_j1 = self._bessel_j1_torch
        self.bessel_k1 = self._bessel_k1_torch
        self.lgamma = self._lgamma_torch
        self.hessian = self._hessian_torch
        self.jacobian = self._jacobian_torch
        self.jacfwd = self._jacfwd_torch
        self.grad = self._grad_torch
        self.vmap = self._vmap_torch
        self.long = self._long_torch
        self.detach = lambda x: x.detach()
        self.fill_at_indices = self._fill_at_indices_torch
        self.add_at_indices = self._add_at_indices_torch

    def setup_jax(self):
        """Enable 64-bit jax, clear the config defaults and bind the jax helpers."""
        self.jax = importlib.import_module("jax")
        self.jax.config.update("jax_enable_x64", True)
        config.DTYPE = None
        config.DEVICE = None
        self.make_array = self._make_array_jax
        self._array_type = self._array_type_jax
        self.concatenate = self._concatenate_jax
        self.copy = self._copy_jax
        self.tolist = self._tolist_jax
        self.view = self._view_jax
        self.as_array = self._as_array_jax
        self.to = self._to_jax
        self.to_numpy = self._to_numpy_jax
        self.gammaln = self._gammaln_jax
        self.logit = self._logit_jax
        self.sigmoid = self._sigmoid_jax
        self.repeat = self._repeat_jax
        self.stack = self._stack_jax
        self.transpose = self._transpose_jax
        self.upsample2d = self._upsample2d_jax
        self.pad = self._pad_jax
        # No dedicated linear-algebra error type is bound for jax.
        self.LinAlgErr = Exception
        self.roll = self._roll_jax
        self.clamp = self._clamp_jax
        self.flatten = self._flatten_jax
        self.conv2d = self._conv2d_jax
        self.mean = self._mean_jax
        self.sum = self._sum_jax
        self.max = self._max_jax
        self.topk = self._topk_jax
        self.bessel_j1 = self._bessel_j1_jax
        self.bessel_k1 = self._bessel_k1_jax
        self.lgamma = self._lgamma_jax
        self.hessian = self._hessian_jax
        self.jacobian = self._jacobian_jax
        self.jacfwd = self._jacfwd_jax
        self.grad = self._grad_jax
        self.vmap = self._vmap_jax
        self.long = self._long_jax
        self.detach = lambda x: x
        self.fill_at_indices = self._fill_at_indices_jax
        self.add_at_indices = self._add_at_indices_jax

    @property
    def array_type(self):
        """Array class of the active backend."""
        return self._array_type()

    # ------------------------------------------------------------------
    # Construction, conversion and copying
    # ------------------------------------------------------------------
    def _make_array_torch(self, array, dtype=None, device=None):
        """Build a new tensor from ``array``."""
        return self.module.tensor(array, dtype=dtype, device=device)

    def _make_array_jax(self, array, dtype=None, **kwargs):
        """Build a new jax array from ``array``; extra keywords are ignored."""
        return self.module.array(array, dtype=dtype)

    def _array_type_torch(self):
        return self.module.Tensor

    def _array_type_jax(self):
        return self.module.ndarray

    def _concatenate_torch(self, arrays, dim=0):
        return self.module.cat(arrays, dim=dim)

    def _concatenate_jax(self, arrays, dim=0):
        return self.module.concatenate(arrays, axis=dim)

    def _copy_torch(self, array):
        """Return a detached clone of ``array``."""
        return array.detach().clone()

    def _copy_jax(self, array):
        return self.module.copy(array)

    def _tolist_torch(self, array):
        return array.detach().cpu().tolist()

    def _tolist_jax(self, array):
        return array.block_until_ready().tolist()

    def _view_torch(self, array, shape):
        return array.reshape(shape)

    def _view_jax(self, array, shape):
        return array.reshape(shape)

    def _as_array_torch(self, array, dtype=None, device=None):
        return self.module.as_tensor(array, dtype=dtype, device=device)

    def _as_array_jax(self, array, dtype=None, **kwargs):
        """Convert ``array`` to a jax array; extra keywords are ignored."""
        return self.module.asarray(array, dtype=dtype)

    def _to_torch(self, array, dtype=None, device=None):
        return array.to(dtype=dtype, device=device)

    def _to_jax(self, array, dtype=None, device=None):
        # NOTE(review): ``astype(None)`` may cast to the default float type,
        # whereas torch's ``.to(dtype=None)`` keeps the dtype - confirm that
        # callers always pass an explicit dtype on this path.
        return self.jax.device_put(array.astype(dtype), device=device)

    def _to_numpy_torch(self, array):
        return array.detach().cpu().numpy()

    def _to_numpy_jax(self, array):
        return np.array(array.block_until_ready())

    # ------------------------------------------------------------------
    # Shape manipulation
    # ------------------------------------------------------------------
    def _repeat_torch(self, a, repeats, axis=None):
        return self.module.repeat_interleave(a, repeats, dim=axis)

    def _repeat_jax(self, a, repeats, axis=None):
        return self.module.repeat(a, repeats, axis=axis)

    def _stack_torch(self, arrays, dim=0):
        return self.module.stack(arrays, dim=dim)

    def _stack_jax(self, arrays, dim=0):
        return self.module.stack(arrays, axis=dim)

    def _transpose_torch(self, array, *args):
        """Swap the two dimensions given in ``args``."""
        return self.module.transpose(array, *args)

    def _transpose_jax(self, array, *args):
        """Swap the two dimensions given in ``args``.

        ``swapaxes`` gives the same result for either argument order and for
        negative indices, matching ``torch.transpose``.
        """
        return self.module.swapaxes(array, *args)

    def _flatten_torch(self, array, start_dim=0, end_dim=-1):
        return array.flatten(start_dim, end_dim)

    def _flatten_jax(self, array, start_dim=0, end_dim=-1):
        """Collapse dimensions ``start_dim`` through ``end_dim`` (inclusive)."""
        shape = tuple(array.shape)
        ndim = len(shape)
        if ndim == 0:
            # A scalar flattens to a single-element vector, as in torch.
            return self.module.reshape(array, (1,))
        start = start_dim % ndim
        stop = (end_dim % ndim) + 1
        # Explicit length instead of -1, which is ambiguous when another
        # dimension has size zero.
        flat = int(np.prod(shape[start:stop], dtype=np.int64))
        return self.module.reshape(array, shape[:start] + (flat,) + shape[stop:])

    def _roll_torch(self, array, shifts, dims):
        return self.module.roll(array, shifts, dims=dims)

    def _roll_jax(self, array, shifts, dims):
        return self.module.roll(array, shifts, axis=dims)

    # ------------------------------------------------------------------
    # Special functions
    # ------------------------------------------------------------------
    def _gammaln_torch(self, array):
        return self.module.special.gammaln(array)

    def _gammaln_jax(self, array):
        return self.jax.scipy.special.gammaln(array)

    def _sigmoid_torch(self, array):
        return self.module.sigmoid(array)

    def _sigmoid_jax(self, array):
        return self.jax.nn.sigmoid(array)

    def _logit_torch(self, array):
        return self.module.logit(array)

    def _logit_jax(self, array):
        return self.jax.scipy.special.logit(array)

    def _bessel_j1_torch(self, array):
        return self.module.special.bessel_j1(array)

    def _bessel_j1_jax(self, array):
        # bessel_jn is evaluated up to v=1 and the last entry of the result
        # along the leading axis is returned.
        return self.jax.scipy.special.bessel_jn(array, v=1)[-1]

    def _bessel_k1_torch(self, array):
        return self.module.special.modified_bessel_k1(array)

    def _bessel_k1_jax(self, array):
        # NOTE(review): confirm that ``jax.scipy.special.kn`` exists in the
        # supported jax versions.
        return self.jax.scipy.special.kn(1, array)

    def _lgamma_torch(self, array):
        return self.module.lgamma(array)

    def _lgamma_jax(self, array):
        return self.jax.lax.lgamma(array)

    # ------------------------------------------------------------------
    # Image-style operations
    # ------------------------------------------------------------------
    def _upsample2d_torch(self, array, scale_factor, method):
        """Upsample by ``scale_factor`` and divide by ``scale_factor**2``."""
        upsampler = self.module.nn.Upsample(scale_factor=scale_factor, mode=method)
        return upsampler(array) / scale_factor**2

    def _upsample2d_jax(self, array, scale_factor, method):
        """Upsample the last two axes and divide by ``scale_factor**2``.

        Mirrors the torch path: ``method`` is forwarded unchanged to
        ``jax.image.resize`` and the result is normalised the same way.
        """
        new_shape = list(array.shape)
        new_shape[-2] = array.shape[-2] * scale_factor
        new_shape[-1] = array.shape[-1] * scale_factor
        resized = self.jax.image.resize(array, new_shape, method=method)
        return resized / scale_factor**2

    def _pad_torch(self, array, padding, mode):
        """Pad ``array`` using only the last four entries of ``padding``."""
        return self.module.nn.functional.pad(array, padding[-4:], mode=mode)

    def _pad_jax(self, array, padding, mode):
        """Pad ``array``; ``padding`` is regrouped into (before, after) pairs."""
        if mode == "replicate":
            mode = "edge"  # the numpy-style name for the same mode
        # NOTE(review): torch reads padding pairs starting from the last
        # dimension, numpy-style pad widths start from the first - confirm
        # that callers build ``padding`` with this difference in mind.
        padding = np.array(padding).reshape(-1, 2)
        return self.module.pad(array, padding, mode=mode)

    def _conv2d_torch(self, input, kernel, padding, stride=1):
        return self.module.nn.functional.conv2d(
            input,
            kernel,
            padding=padding,
            stride=stride,
        )

    def _conv2d_jax(self, input, kernel, padding, stride=1):
        return self.jax.lax.conv_general_dilated(
            input, kernel, window_strides=(stride, stride), padding=padding
        )

    # ------------------------------------------------------------------
    # Element-wise helpers and reductions
    # ------------------------------------------------------------------
    def _clamp_torch(self, array, min, max):
        return self.module.clamp(array, min, max)

    def _clamp_jax(self, array, min, max):
        return self.module.clip(array, min, max)

    def _long_torch(self, array):
        return array.long()

    def _long_jax(self, array):
        return self.module.astype(array, self.module.int64)

    def _mean_torch(self, array, dim=None):
        return self.module.mean(array, dim=dim)

    def _mean_jax(self, array, dim=None):
        return self.module.mean(array, axis=dim)

    def _sum_torch(self, array, dim=None, keepdim=False):
        return self.module.sum(array, dim=dim, keepdim=keepdim)

    def _sum_jax(self, array, dim=None, keepdim=False):
        return self.module.sum(array, axis=dim, keepdims=keepdim)

    def _max_torch(self, array, dim=None):
        """Maximum over ``dim``, or over every element when ``dim`` is None."""
        if dim is None:
            # Do not forward None: amax is called without a dim so that it
            # reduces over all dimensions.
            return array.amax()
        return array.amax(dim=dim)

    def _max_jax(self, array, dim=None):
        return self.module.max(array, axis=dim)

    def _topk_torch(self, array, k):
        return self.module.topk(array, k=k)

    def _topk_jax(self, array, k):
        return self.jax.lax.top_k(array, k=k)

    # ------------------------------------------------------------------
    # Differentiation and vectorisation
    # ------------------------------------------------------------------
    def _grad_torch(self, func):
        return self.module.func.grad(func)

    def _grad_jax(self, func):
        return self.jax.grad(func)

    def _jacobian_torch(
        self, func, x, strategy="forward-mode", vectorize=True, create_graph=False
    ):
        return self.module.autograd.functional.jacobian(
            func, x, strategy=strategy, vectorize=vectorize, create_graph=create_graph
        )

    def _jacobian_jax(
        self, func, x, strategy="forward-mode", vectorize=True, create_graph=False
    ):
        """Jacobian of ``func`` at ``x``.

        Forward mode is used when ``strategy`` contains ``"forward"``,
        reverse mode otherwise. ``vectorize`` and ``create_graph`` are
        accepted for signature parity with the torch helper and are unused.
        """
        if "forward" in strategy:
            return self.jax.jacfwd(func)(x)
        return self.jax.jacrev(func)(x)

    def _jacfwd_torch(self, func, argnums=0):
        return self.module.func.jacfwd(func, argnums=argnums)

    def _jacfwd_jax(self, func, argnums=0):
        return self.jax.jacfwd(func, argnums=argnums)

    def _hessian_torch(self, func):
        return self.module.func.hessian(func)

    def _hessian_jax(self, func):
        return self.jax.hessian(func)

    def _vmap_torch(self, *args, in_dims=0, **kwargs):
        return self.module.vmap(*args, in_dims=in_dims, **kwargs)

    def _vmap_jax(self, *args, in_dims=0, **kwargs):
        return self.jax.vmap(*args, in_axes=in_dims, **kwargs)

    # ------------------------------------------------------------------
    # Out-of-place indexed updates
    # ------------------------------------------------------------------
    def _fill_at_indices_torch(self, array, indices, values):
        """Return a copy of ``array`` with ``array[indices]`` set to ``values``."""
        if isinstance(indices, self.module.Tensor) and indices.dtype != self.module.bool:
            # Long (integer) tensor indices: use index_put for vmap+jacfwd compatibility
            return array.index_put((indices,), values)
        # Bool tensor or tuple/slice indices: use clone + in-place
        array = array.clone()
        array[indices] = values
        return array

    def _fill_at_indices_jax(self, array, indices, values):
        return array.at[indices].set(values)

    def _add_at_indices_torch(self, array, indices, values):
        """Return a copy of ``array`` with ``values`` added at ``indices``."""
        if isinstance(indices, self.module.Tensor) and indices.dtype != self.module.bool:
            # Long (integer) tensor indices: use index_put for vmap+jacfwd compatibility
            return array.index_put((indices,), values, accumulate=True)
        # Bool tensor or tuple/slice indices: use clone + in-place
        array = array.clone()
        array[indices] += values
        return array

    def _add_at_indices_jax(self, array, indices, values):
        return array.at[indices].add(values)

    # ------------------------------------------------------------------
    # Operations forwarded to the same-named function on ``self.module``
    # ------------------------------------------------------------------
    def arange(self, *args, dtype=None, device=None):
        return self.module.arange(*args, dtype=dtype, device=device)

    def linspace(self, start, end, steps, dtype=None, device=None):
        return self.module.linspace(start, end, steps, dtype=dtype, device=device)

    def meshgrid(self, *arrays, indexing="ij"):
        return self.module.meshgrid(*arrays, indexing=indexing)

    def searchsorted(self, array, value):
        return self.module.searchsorted(array, value)

    def any(self, array):
        return self.module.any(array)

    def all(self, array):
        return self.module.all(array)

    def log(self, array):
        return self.module.log(array)

    def log10(self, array):
        return self.module.log10(array)

    def exp(self, array):
        return self.module.exp(array)

    def sin(self, array):
        return self.module.sin(array)

    def cos(self, array):
        return self.module.cos(array)

    def cosh(self, array):
        return self.module.cosh(array)

    def sqrt(self, array):
        return self.module.sqrt(array)

    def abs(self, array):
        return self.module.abs(array)

    def floor(self, array):
        return self.module.floor(array)

    def tanh(self, array):
        return self.module.tanh(array)

    def arctan(self, array):
        return self.module.arctan(array)

    def arctan2(self, y, x):
        return self.module.arctan2(y, x)

    def arcsin(self, array):
        return self.module.arcsin(array)

    def round(self, array):
        return self.module.round(array)

    def zeros(self, shape, dtype=None, device=None):
        return self.module.zeros(shape, dtype=dtype, device=device)

    def zeros_like(self, array, dtype=None):
        return self.module.zeros_like(array, dtype=dtype)

    def ones(self, shape, dtype=None, device=None):
        return self.module.ones(shape, dtype=dtype, device=device)

    def ones_like(self, array, dtype=None):
        return self.module.ones_like(array, dtype=dtype)

    def empty(self, shape, dtype=None, device=None):
        return self.module.empty(shape, dtype=dtype, device=device)

    def eye(self, n, dtype=None, device=None):
        return self.module.eye(n, dtype=dtype, device=device)

    def diag(self, array):
        return self.module.diag(array)

    def outer(self, a, b):
        return self.module.outer(a, b)

    def minimum(self, a, b):
        return self.module.minimum(a, b)

    def maximum(self, a, b):
        return self.module.maximum(a, b)

    def isnan(self, array):
        return self.module.isnan(array)

    def isfinite(self, array):
        return self.module.isfinite(array)

    def nan_to_num(self, array, nan=0.0, posinf=None, neginf=None):
        return self.module.nan_to_num(array, nan=nan, posinf=posinf, neginf=neginf)

    def where(self, condition, x, y):
        return self.module.where(condition, x, y)

    def allclose(self, a, b, rtol=1e-5, atol=1e-8):
        return self.module.allclose(a, b, rtol=rtol, atol=atol)

    # ------------------------------------------------------------------
    # Sub-modules, constants and dtypes of the active backend
    # ------------------------------------------------------------------
    @property
    def linalg(self):
        return self.module.linalg

    @property
    def fft(self):
        return self.module.fft

    @property
    def inf(self):
        return self.module.inf

    @property
    def bool(self):
        return self.module.bool

    @property
    def int32(self):
        return self.module.int32

    @property
    def float32(self):
        return self.module.float32

    @property
    def float64(self):
        return self.module.float64
# Module-level instance created at import time. With no argument the backend
# name comes from the CASKADE_BACKEND environment variable (default "torch").
backend = Backend()