import numpy as np
from ...errors import OptimizeStopFail, OptimizeStopSuccess
from ...backend_obj import backend
from ... import config
[docs]
def nll(D, M, W):
    """
    Negative log-likelihood for Gaussian noise.

    D: data
    M: model prediction
    W: weights applied to the squared residuals

    Returns 0.5 * sum(W * (D - M)**2), summed over the last axis.
    """
    residual = D - M
    return 0.5 * backend.sum(W * residual**2, dim=-1)
[docs]
def nll_poisson(D, M):
    """
    Negative log-likelihood for Poisson noise (terms independent of M dropped).

    D: data
    M: model prediction

    Returns sum(M - D * log(M + 1e-10)), summed over the last axis.
    """
    # The tiny offset keeps the logarithm finite where the model is exactly 0.
    safe_log_model = backend.log(M + 1e-10)
    return backend.sum(M - D * safe_log_model, dim=-1)
[docs]
def gradient(J, W, D, M):
    """
    Negative gradient of the Gaussian ``nll`` with respect to the parameters.

    J: Jacobian of the model prediction
    W: weights
    D: data
    M: model prediction

    Returns J.T @ (W * (D - M)) as a column vector (trailing axis of size 1).
    """
    weighted_residual = W * (D - M)
    column = weighted_residual.reshape(D.shape + (1,))
    return J.T @ column
[docs]
def gradient_poisson(J, D, M):
    """
    Negative gradient of ``nll_poisson`` with respect to the parameters.

    J: Jacobian of the model prediction
    D: data
    M: model prediction

    Returns J.T @ (D / M - 1) as a column vector (trailing axis of size 1).

    A 1e-10 offset is added to M in the denominator, matching the offsets
    used by nll_poisson and hessian_poisson. Without it, entries where both
    data and model are zero (e.g. pixels zeroed by a mask) evaluate to
    0/0 = nan. Multiplying that nan by a zeroed Jacobian row would still
    give nan and poison the whole gradient.
    """
    ratio = D / (M + 1e-10)
    return J.T @ (ratio - 1).reshape(D.shape + (1,))
[docs]
def hessian(J, W):
    """
    Gauss-Newton curvature J.T @ diag(W) @ J for the Gaussian ``nll``.

    J: Jacobian of the model prediction
    W: weights, one per row of J
    """
    row_weights = W.reshape(W.shape + (1,))
    weighted_jacobian = row_weights * J
    return J.T @ weighted_jacobian
[docs]
def hessian_poisson(J, D, M):
    """
    Curvature J.T @ diag(D / M**2) @ J for ``nll_poisson``.

    D / M**2 is the second derivative of (M - D * log M) with respect to M.
    The 1e-10 offset keeps the division finite where M is 0.

    J: Jacobian of the model prediction
    D: data
    M: model prediction
    """
    curvature = D / (M**2 + 1e-10)
    return J.T @ (curvature.reshape(D.shape + (1,)) * J)
[docs]
def damp_hessian(hess, L):
    """
    Damp a square curvature matrix for a Levenberg-Marquardt solve.

    hess: (N, N) curvature matrix
    L: damping factor

    Diagonal entries are multiplied by (1 + L) and off-diagonal entries are
    divided by (1 + L).
    """
    eye = backend.eye(len(hess), dtype=config.DTYPE, device=config.DEVICE)
    off_diagonal = backend.ones_like(hess) - eye
    scaling = eye + off_diagonal / (1 + L)
    # eye * diag(hess) broadcasts the diagonal vector back onto the diagonal.
    return hess * scaling + L * eye * backend.diag(hess)
[docs]
def rho(nll0, nll1, h, hessD, grad):
    """
    Ratio of the actual nll decrease to a linearised prediction.

    nll0: nll before the step
    nll1: nll after the step
    h: proposed step (column vector)
    hessD: damped matrix used to solve for h
    grad: right-hand side used to solve for h

    NOTE(review): when h solves hessD @ h == grad (as returned by ``solve``),
    the denominator reduces to |grad.T @ h|. The damped quadratic model
    predicts a decrease of 0.5 * grad.T @ h, so this ratio sits near 0.5 for
    a well-modelled step at small damping. Confirm that lm_step's acceptance
    window (0.1, 2) is meant for this scaling.
    """
    predicted = h.T @ hessD @ h - 2 * grad.T @ h
    return (nll0 - nll1) / backend.abs(predicted)
[docs]
def solve(hess, grad, L):
    """
    Solve the damped normal equations hessD @ h = grad.

    hess: (N, N) curvature matrix
    grad: (N, 1) right-hand side
    L: damping factor passed to ``damp_hessian``

    If the damped matrix is singular (``backend.LinAlgErr``), a multiple of
    the identity is added and the multiple is doubled until the solve
    succeeds.

    Returns (hessD, h): the matrix actually used for the successful solve
    and the step.
    """
    hessD = damp_hessian(hess, L)  # (N, N)
    while True:
        try:
            h = backend.linalg.solve(hessD, grad)
        except backend.LinAlgErr:
            # damp_hessian scales the diagonal multiplicatively, so a zero
            # diagonal entry stays zero for any L. With L == 0 the original
            # retry added 0 * I and doubled 0 forever. A tiny positive floor
            # guarantees the added multiple grows and the loop terminates.
            L = L + 1e-10
            hessD = hessD + L * backend.eye(len(hessD), dtype=config.DTYPE, device=config.DEVICE)
            L = L * 2
        else:
            return hessD, h
[docs]
def lm_step(
    x,
    data,
    model,
    weight,
    jacobian,
    L=1.0,
    Lup=9.0,
    Ldn=11.0,
    tolerance=1e-4,
    likelihood="gaussian",
    max_step_iter=10,
):
    """
    Take one Levenberg-Marquardt step from ``x``.

    Up to ``max_step_iter`` trial steps are solved at the current damping
    ``L``. ``L`` is multiplied by ``Lup`` after a rejected trial and divided
    by ``Ldn`` after an accepted one.

    x: current parameter vector
    data: observed data
    model: callable mapping parameters to a model prediction
    weight: weights (used only by the Gaussian likelihood)
    jacobian: callable returning the Jacobian of ``model`` at given parameters
    L: initial damping factor
    Lup: factor applied to L after a rejected trial
    Ldn: divisor applied to L after an accepted trial
    tolerance: slack on nll when deciding whether to record or fall back on a
        "scary" step (one whose rho is outside the trusted range)
    likelihood: "gaussian" or "poisson"
    max_step_iter: maximum number of trial steps (default 10, the previously
        hard-coded value)

    Returns a dict with "x", "nll" and "L". A fallback "scary" result also
    carries "rho".

    Raises OptimizeStopSuccess when the gradient or the first step is zero,
    OptimizeStopFail when no acceptable step is found, and ValueError for an
    unknown likelihood.
    """
    L0 = L
    M0 = backend.detach(model(x))  # (M,)
    J = backend.detach(jacobian(x))  # (M, N)
    if likelihood == "gaussian":
        nll0 = nll(data, M0, weight).item()
        grad = gradient(J, weight, data, M0)  # (N, 1)
        hess = hessian(J, weight)  # (N, N)
    elif likelihood == "poisson":
        nll0 = nll_poisson(data, M0).item()
        grad = gradient_poisson(J, data, M0)  # (N, 1)
        hess = hessian_poisson(J, data, M0)  # (N, N)
    else:
        raise ValueError(f"Unsupported likelihood: {likelihood}")
    del J  # only grad and hess are needed from here on
    if backend.allclose(grad, backend.zeros_like(grad)):
        raise OptimizeStopSuccess("Gradient is zero, optimization converged.")

    # Relative changes are measured against |nll0|. The Poisson nll can be
    # negative, and dividing by a negative nll0 would flip the sign of every
    # relative-improvement test below (accepting worse steps). If nll0 is
    # exactly zero, fall back to an absolute change instead of dividing by 0.
    nll_scale = abs(nll0) or 1.0

    best = {"x": backend.zeros_like(x), "nll": nll0, "L": L}
    # Fallback candidate recorded even when rho is outside the trusted range;
    # only returned if no regular step is accepted.
    scary = {"x": None, "nll": np.inf, "L": None, "rho": np.inf}
    nostep = True
    improving = None  # None before any trial, then the outcome of the last trial
    for i in range(max_step_iter):
        hessD, h = solve(hess, grad, L)  # (N, N), (N, 1)
        x1 = x + h.squeeze(1)
        M1 = model(x1)  # (M,)
        if likelihood == "gaussian":
            nll1 = nll(data, M1, weight).item()
        elif likelihood == "poisson":
            nll1 = nll_poisson(data, M1).item()
        # Non-finite nll: damp harder and retry
        if not np.isfinite(nll1):
            L *= Lup
            if improving is True:
                break
            improving = False
            continue
        if backend.allclose(h, backend.zeros_like(h)) and L < 0.1:
            if i == 0:
                raise OptimizeStopSuccess("Step with zero length means optimization complete.")
            break
        # actual nll improvement vs expected from linearization
        _rho = rho(nll0, nll1, h, hessD, grad).item()
        if (nll1 < (nll0 + tolerance) and abs(_rho - 1) < abs(scary["rho"] - 1)) or (
            nll1 < scary["nll"] and _rho > -10
        ):
            # records the initial damping L0, not the trial L
            scary = {"x": x1, "nll": nll1, "L": L0, "rho": _rho}
        # Avoid highly non-linear regions
        if _rho < 0.1 or _rho > 2:
            L *= Lup
            if improving is True:
                break
            improving = False
            continue
        if nll1 < best["nll"]:  # new best
            best = {"x": x1, "nll": nll1, "L": L}
            nostep = False
            L /= Ldn
            if L < 1e-8 or improving is False:
                break
            improving = True
        elif improving is True:  # were improving, now not improving
            break
        else:  # not improving and bad nll, damp more
            L *= Lup
            if L >= 1e9:
                break
            improving = False
        # If we are improving nll by more than 10% then we can stop
        if (best["nll"] - nll0) / nll_scale < -0.1:
            break
    if nostep:
        if scary["x"] is not None and (scary["nll"] - nll0) / nll_scale < tolerance:
            return scary
        raise OptimizeStopFail("Could not find step to improve chi^2")
    return best
[docs]
def batch_lm_step(
    x,
    data,
    model,
    weight,
    mask,
    jacobian,
    L=1.0,
    Lup=9.0,
    Ldn=11.0,
    tolerance=1e-4,
    likelihood="gaussian",
    max_step_iter=3,
):
    """
    Take one Levenberg-Marquardt step for a batch of independent problems.

    The per-problem gradient, curvature, solve and rho are evaluated with
    ``backend.vmap``. Each of the ``max_step_iter`` trials proposes a step
    from the original ``x`` at the current per-problem damping. A trial is
    kept for a problem when its nll is finite, lower than the best so far,
    and its rho lies strictly between 0.1 and 2. Damping is divided by
    ``Ldn`` where the trial was kept and multiplied by ``Lup`` elsewhere.

    mask: multiplicative mask applied to the data, model outputs (both the
        starting and the trial predictions), weights and Jacobian rows
    tolerance: accepted for signature parity with lm_step; not used here

    Returns a dict with "x", "nll" (converted via backend.to_numpy) and "L"
    (per-problem damping).
    """
    L0 = L  # (D,)
    M0 = backend.detach(model(x))  # (D, M)
    J = backend.detach(jacobian(x))  # (D, M, N)
    data = data * mask
    M0 = M0 * mask
    weight = weight * mask
    J = J * mask.reshape(mask.shape + (1,))
    if likelihood == "gaussian":
        nll0 = nll(data, M0, weight)  # (D,)
        grad = backend.vmap(gradient)(J, weight, data, M0)  # (D, N, 1)
        hess = backend.vmap(hessian)(J, weight)  # (D, N, N)
    elif likelihood == "poisson":
        nll0 = nll_poisson(data, M0)  # (D,)
        grad = backend.vmap(gradient_poisson)(J, data, M0)  # (D, N, 1)
        hess = backend.vmap(hessian_poisson)(J, data, M0)  # (D, N, N)
    else:
        raise ValueError(f"Unsupported likelihood: {likelihood}")
    del J
    new_x = backend.copy(x)
    new_nll = backend.copy(nll0)
    new_L = backend.copy(L)
    for _ in range(max_step_iter):
        hessD, h = backend.vmap(solve)(hess, grad, new_L)  # (D, N, N), (D, N, 1)
        x1 = x + h.squeeze(2)
        # Mask the trial prediction exactly like M0. Otherwise the Poisson nll
        # of the trial counts model flux in masked pixels that nll0 excluded,
        # biasing both the acceptance test and rho. For the Gaussian nll the
        # masked weights already zero those pixels.
        M1 = model(x1) * mask  # (D, M)
        if likelihood == "gaussian":
            nll1 = nll(data, M1, weight)  # (D,)
        elif likelihood == "poisson":
            nll1 = nll_poisson(data, M1)  # (D,)
        # actual nll improvement vs expected from linearization
        _rho = backend.vmap(rho)(nll0, nll1, h, hessD, grad).reshape(-1)  # (D,)
        good = backend.isfinite(nll1) & (nll1 < new_nll) & (_rho > 0.1) & (_rho < 2)
        new_x = backend.where(good[:, None], x1, new_x)
        new_nll = backend.where(good, nll1, new_nll)
        new_L = backend.where(good, new_L / Ldn, new_L * Lup)
    return {"x": new_x, "nll": backend.to_numpy(new_nll), "L": new_L}