Source code for everwillow._src.inference.uncertainty

"""Parameter uncertainty estimation from Hessian-based covariance.

This module provides functions for computing parameter uncertainties using
the covariance matrix, which is obtained by inverting the Hessian of the
negative log-likelihood (the observed Fisher information) at the maximum
likelihood estimate.

The key relationship is the Cramér-Rao bound: the covariance matrix of
unbiased estimators is bounded below by the inverse Fisher information.
At the MLE, under regularity conditions, this bound is saturated.
"""

from __future__ import annotations

import typing as tp
from types import EllipsisType

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PyTree

import everwillow._src.statelib as sl
from everwillow._src.statelib import V


def hessian_matrix(
    nll_fn: tp.Callable[[PyTree[V], PyTree], float],
    params: sl.State[V],
    observation: PyTree,
    *,
    fixed: sl.State[V | EllipsisType] | None = None,
) -> Float[Array, "n_free n_free"]:
    """Compute the Hessian matrix of the NLL at given parameters.

    The Hessian H_ij = ∂²NLL/∂θ_i∂θ_j is computed only for free (non-fixed)
    parameters using JAX automatic differentiation. Rows and columns follow
    the order of ``free_state.notnone.keys()`` after partitioning ``params``
    by membership in ``fixed``.

    Args:
        nll_fn: Negative log-likelihood function taking (params, observation).
        params: Full parameter state at which to evaluate the Hessian.
        observation: Observed data passed to nll_fn.
        fixed: Parameters to treat as constants (excluded from Hessian).
            ``None`` means no parameter is fixed.

    Returns:
        2D JAX array of shape (n_free, n_free).

    Raises:
        TypeError: If ``params`` is not a State, or ``fixed`` is neither a
            State nor None.
    """
    if not isinstance(params, sl.State):
        msg = "params must be a State"  # type: ignore[unreachable]
        raise TypeError(msg)

    if fixed is None:
        fixed = sl.State.from_pytree({})

    if not isinstance(fixed, sl.State):
        msg = "fixed must be a State or None"  # type: ignore[unreachable]
        raise TypeError(msg)

    # Split into fixed and free
    fixed_state, free_state = sl.partition(params, predicate=lambda key, _: key in fixed)

    # Get flat array of free values
    free_keys = tuple(free_state.notnone.keys())
    flat_values = jnp.array([free_state[k] for k in free_keys])

    # jax.hessian only differentiates with respect to inexact (float/complex)
    # inputs. Integer- or bool-valued parameters (e.g. an initial value of 1)
    # would otherwise raise a TypeError, so promote them to the default float.
    if not jnp.issubdtype(flat_values.dtype, jnp.inexact):
        flat_values = flat_values.astype(float)

    def _flat_nll(flat_free: Float[Array, " n_free"]) -> Float[Array, ""]:
        """Evaluate the NLL for a flat vector of free parameter values.

        jax.hessian differentiates with respect to a single array argument,
        so the free parameters are passed as one vector and scattered back
        into the state here.

        Args:
            flat_free: Free parameter values, ordered like ``free_keys``.

        Returns:
            Negative log-likelihood value.
        """
        free_mapping = {k: flat_free[i] for i, k in enumerate(free_keys)}
        new_free = sl.update(free_state, updates=free_mapping)
        combined = sl.combine_partitions(fixed_state, new_free)
        return nll_fn(combined.to_pytree(), observation)

    return jax.hessian(_flat_nll)(flat_values)
def covariance_matrix(
    nll_fn: tp.Callable[[PyTree[V], PyTree], float],
    params: sl.State[V],
    observation: PyTree,
    *,
    fixed: sl.State[V | EllipsisType] | None = None,
) -> Float[Array, "nparams nparams"]:
    """Compute the covariance matrix (inverse Hessian) at given parameters.

    The Hessian of the negative log-likelihood, H_ij = ∂²NLL/∂θ_i∂θ_j, is the
    observed Fisher information. The covariance estimate is its inverse:

        Cov(θ) = H⁻¹

    Only free (non-fixed) parameters enter the Hessian, so the result is
    indexed by free parameters only.

    Args:
        nll_fn: Negative log-likelihood function taking (params, observation).
        params: Full parameter state (typically fitted values).
        observation: Observed data passed to nll_fn.
        fixed: Parameters to exclude from covariance computation.

    Returns:
        2D JAX array of shape (n_free, n_free).
    """
    # Observed Fisher information restricted to the free parameters.
    hessian = hessian_matrix(nll_fn, params, observation, fixed=fixed)

    # Covariance is the matrix inverse of the observed information.
    covariance = jnp.linalg.inv(hessian)
    return covariance
def correlation_matrix(
    nll_fn: tp.Callable[[PyTree[V], PyTree], float],
    params: sl.State[V],
    observation: PyTree,
    *,
    fixed: sl.State[V | EllipsisType] | None = None,
) -> jax.Array:
    """Compute the correlation matrix (normalized covariance).

    Each covariance entry is divided by the product of the corresponding
    standard deviations, and the diagonal is then set to exactly 1.0:

        ρ_ij = Cov_ij / √(Cov_ii · Cov_jj)

    Args:
        nll_fn: Negative log-likelihood function taking (params, observation).
        params: Full parameter state at which to evaluate.
        observation: Observed data passed to nll_fn.
        fixed: Parameters to exclude from correlation computation.

    Returns:
        2D JAX array with diagonal = 1.0, off-diagonal in [-1, 1].
    """
    covariance = covariance_matrix(nll_fn, params, observation, fixed=fixed)

    # Standard deviations σ_i = √(Cov_ii); broadcasting builds σ_i · σ_j.
    sigma = jnp.sqrt(jnp.diag(covariance))
    scale = sigma[:, None] * sigma[None, :]
    normalized = covariance / scale

    # Overwrite the diagonal with exactly 1.0 to remove rounding residue.
    diagonal_mask = jnp.eye(normalized.shape[0], dtype=normalized.dtype)
    return jnp.where(diagonal_mask, 1.0, normalized)
def uncertainties(
    nll_fn: tp.Callable[[PyTree[V], PyTree], float],
    params: sl.State[V],
    observation: PyTree,
    *,
    fixed: sl.State[V | EllipsisType] | None = None,
) -> sl.State[V]:
    """Extract parameter uncertainties as sqrt(diag(covariance)).

    The uncertainties are the square roots of the diagonal of the covariance
    matrix (the inverse of the NLL Hessian):

        σ_i = √(Cov_ii) = √((H⁻¹)_ii)

    Args:
        nll_fn: Negative log-likelihood function taking (params, observation).
        params: Full parameter state (typically fitted values).
        observation: Observed data passed to nll_fn.
        fixed: Parameters to exclude from uncertainty computation.
            ``None`` means no parameter is fixed.

    Returns:
        State holding an uncertainty for each free parameter and ``None``
        for each fixed parameter.
    """
    if fixed is None:
        fixed = sl.State.from_pytree({})

    # Get covariance matrix for free parameters
    cov = covariance_matrix(nll_fn, params, observation, fixed=fixed)

    # Uncertainties are the square roots of the diagonal elements.
    stderrs = jnp.sqrt(jnp.diag(cov))

    # Reproduce the partition used when the Hessian was built.
    fixed_state, free_state = sl.partition(params, predicate=lambda key, _: key in fixed)

    # The Hessian orders its rows/columns by free_state.notnone.keys(), so the
    # standard errors are assigned back by that same key order rather than by
    # pytree leaf order, which is not guaranteed here to match.
    free_keys = tuple(free_state.notnone.keys())
    free_uncertainty = sl.update(
        free_state,
        updates={key: stderrs[i] for i, key in enumerate(free_keys)},
    )

    fixed_uncertainty = jax.tree.map(lambda _: None, fixed_state)
    return sl.combine_partitions(fixed_uncertainty, free_uncertainty)