import numpy as np
from ...errors import OptimizeStopFail, OptimizeStopSuccess
from ...backend_obj import backend
from ... import config
[docs]
def nll(D, M, W):
    """
    Negative log-likelihood for Gaussian noise.

    Evaluates half of the weighted sum of squared residuals between the
    data and the model.

    D: data
    M: model prediction
    W: weights multiplying each squared residual
    """
    residual = D - M
    weighted_sq = W * residual**2
    return 0.5 * backend.sum(weighted_sq)
[docs]
def nll_poisson(D, M):
    """
    Negative log-likelihood for Poisson noise.

    Evaluates the sum of ``M - D * log(M)`` over all elements.

    D: data
    M: model prediction
    """
    tiny = 1e-10  # small offset so the logarithm is never evaluated at 0
    log_model = backend.log(M + tiny)
    return backend.sum(M - D * log_model)
[docs]
def gradient(J, W, D, M):
    """
    Gaussian gradient term: ``J.T @ (W * (D - M))`` as a column.

    J: jacobian of the model, indexed as (data element, parameter)
    W: weights
    D: data
    M: model prediction

    The weighted residual is promoted to a column with ``[:, None]`` before
    the matrix product, so the result has a trailing axis of length one.
    """
    weighted_residual = W * (D - M)
    column = weighted_residual[:, None]
    return J.T @ column
[docs]
def gradient_poisson(J, D, M):
    """
    Poisson gradient term: ``J.T @ (D / M - 1)`` as a column.

    J: jacobian of the model, indexed as (data element, parameter)
    D: data
    M: model prediction

    The same 1e-10 offset used inside the logarithm of ``nll_poisson`` is
    added to the denominator. This keeps the result finite where the model
    is exactly zero and makes this expression the exact negative derivative
    of ``M - D * log(M + 1e-10)``.
    """
    ratio = D / (M + 1e-10)  # guarded against division by zero
    return J.T @ (ratio - 1)[:, None]
[docs]
def hessian(J, W):
    """
    Gaussian approximate Hessian: ``J.T @ diag(W) @ J``.

    J: jacobian of the model, indexed as (data element, parameter)
    W: weights, one per data element

    The weights are broadcast across the rows of ``J`` instead of building
    an explicit diagonal matrix.
    """
    weighted_jacobian = W[:, None] * J
    return J.T @ weighted_jacobian
[docs]
def hessian_poisson(J, D, M):
    """
    Poisson approximate Hessian: ``J.T @ diag(D / M**2) @ J``.

    J: jacobian of the model, indexed as (data element, parameter)
    D: data
    M: model prediction

    A 1e-10 offset is added to ``M**2`` so the per-element curvature stays
    finite where the model is zero.
    """
    curvature = D / (M**2 + 1e-10)
    scaled_jacobian = curvature[:, None] * J
    return J.T @ scaled_jacobian
[docs]
def damp_hessian(hess, L):
    """
    Apply damping of strength ``L`` to a square Hessian.

    hess: square Hessian matrix
    L: damping parameter

    Off-diagonal entries are divided by ``1 + L``. The diagonal is kept and
    additionally receives ``L`` times itself (assuming ``backend.diag``
    returns the diagonal of ``hess`` as a vector that broadcasts against the
    identity mask).
    """
    size = len(hess)
    identity = backend.eye(size, dtype=config.DTYPE, device=config.DEVICE)
    off_diagonal = backend.ones_like(hess) - identity  # 1 everywhere except the diagonal
    shrink = identity + off_diagonal / (1 + L)
    diagonal_boost = L * identity * backend.diag(hess)
    return hess * shrink + diagonal_boost
[docs]
def solve(hess, grad, L):
    """
    Solve the damped linear system for the update step.

    hess: Hessian matrix, (N, N)
    grad: right-hand side, (N, 1)
    L: damping parameter

    The Hessian is first damped with ``damp_hessian``. Whenever the linear
    solve raises ``backend.LinAlgErr``, ``L`` times the identity is added to
    the damped matrix, ``L`` is doubled, and the solve is retried.

    Returns the matrix that was finally solved and the resulting step.

    NOTE(review): there is no retry limit. With ``L == 0`` the added ridge
    is zero, so a persistently failing solve would never terminate --
    confirm callers always pass a positive ``L``.
    """
    damped = damp_hessian(hess, L)  # (N, N)
    step = None
    while step is None:
        try:
            step = backend.linalg.solve(damped, grad)
        except backend.LinAlgErr:
            # Add a ridge to the diagonal and try again with a larger one next time.
            ridge = backend.eye(len(damped), dtype=config.DTYPE, device=config.DEVICE)
            damped = damped + L * ridge
            L = L * 2
    return damped, step
[docs]
def lm_step(
    x,
    data,
    model,
    weight,
    jacobian,
    L=1.0,
    Lup=9.0,
    Ldn=11.0,
    tolerance=1e-4,
    likelihood="gaussian",
):
    """
    Take one Levenberg-Marquardt step, searching over the damping parameter.

    x: current parameter vector, passed to ``model`` and ``jacobian``
    data: observed data
    model: callable mapping parameters to a model prediction
    weight: weights, only used for the "gaussian" likelihood
    jacobian: callable mapping parameters to the jacobian of the model
    L: initial damping parameter
    Lup: factor by which ``L`` is multiplied after a rejected trial
    Ldn: factor by which ``L`` is divided after an accepted trial
    tolerance: slack used when deciding whether a fallback ("scary") step
        is acceptable
    likelihood: "gaussian" or "poisson"

    Up to ten damping values are tried. Returns a dict with keys "x", "nll"
    and "L" for the best accepted step. If no trial was accepted but a
    fallback step exists whose nll is not worse than the starting nll by
    more than ``tolerance`` (relative to ``abs(nll0)``), that fallback dict
    (which additionally carries "rho") is returned instead.

    Raises ValueError for an unknown ``likelihood``, OptimizeStopSuccess when
    the gradient or the very first step is numerically zero, and
    OptimizeStopFail when no usable step is found.
    """
    L0 = L
    M0 = backend.detach(model(x))  # (M,)
    J = backend.detach(jacobian(x))  # (M, N)

    if likelihood == "gaussian":
        nll0 = nll(data, M0, weight).item()
        grad = gradient(J, weight, data, M0)  # (N, 1)
        hess = hessian(J, weight)  # (N, N)
    elif likelihood == "poisson":
        nll0 = nll_poisson(data, M0).item()
        grad = gradient_poisson(J, data, M0)  # (N, 1)
        hess = hessian_poisson(J, data, M0)  # (N, N)
    else:
        raise ValueError(f"Unsupported likelihood: {likelihood}")

    # The jacobian is no longer needed once grad and hess are built.
    del J

    if backend.allclose(grad, backend.zeros_like(grad)):
        raise OptimizeStopSuccess("Gradient is zero, optimization converged.")

    # Relative changes are measured against the magnitude of the starting
    # nll. The Poisson nll can be negative, so dividing by the signed value
    # would flip the comparisons below (and divide by zero when nll0 == 0).
    nll_scale = abs(nll0)

    best = {"x": backend.zeros_like(x), "nll": nll0, "L": L}
    # Fallback step, used only if no trial passes the rho test.
    scary = {"x": None, "nll": np.inf, "L": None, "rho": np.inf}
    nostep = True
    # None: no trend yet; True: last accepted trial improved; False: last trial was rejected.
    improving = None

    for i in range(10):
        hessD, h = solve(hess, grad, L)  # (N, N), (N, 1)
        M1 = model(x + h.squeeze(1))  # (M,)
        if likelihood == "gaussian":
            nll1 = nll(data, M1, weight).item()
        elif likelihood == "poisson":
            nll1 = nll_poisson(data, M1).item()

        # Handle nan chi2
        if not np.isfinite(nll1):
            L *= Lup
            if improving is True:
                break
            improving = False
            continue

        if backend.allclose(h, backend.zeros_like(h)) and L < 0.1:
            if i == 0:
                raise OptimizeStopSuccess("Step with zero length means optimization complete.")
            break

        # actual nll improvement vs expected from linearization
        rho = (nll0 - nll1) / backend.abs(h.T @ hessD @ h - 2 * grad.T @ h).item()

        # Remember the most trustworthy trial so far as a fallback.
        if (nll1 < (nll0 + tolerance) and abs(rho - 1) < abs(scary["rho"] - 1)) or (
            nll1 < scary["nll"] and rho > -10
        ):
            scary = {"x": x + h.squeeze(1), "nll": nll1, "L": L0, "rho": rho}

        # Avoid highly non-linear regions
        if rho < 0.1 or rho > 2:
            L *= Lup
            if improving is True:
                break
            improving = False
            continue

        if nll1 < best["nll"]:  # new best
            best = {"x": x + h.squeeze(1), "nll": nll1, "L": L}
            nostep = False
            L /= Ldn
            if L < 1e-8 or improving is False:
                break
            improving = True
        elif improving is True:  # were improving, now not improving
            break
        else:  # not improving and bad chi2, damp more
            L *= Lup
            if L >= 1e9:
                break
            improving = False

        # If we are improving chi2 by more than 10% then we can stop
        if best["nll"] - nll0 < -0.1 * nll_scale:
            break

    if nostep:
        if scary["x"] is not None and scary["nll"] - nll0 < tolerance * nll_scale:
            return scary
        raise OptimizeStopFail("Could not find step to improve chi^2")

    return best