# Source code for everwillow._src.inference.fitting

"""Core fitting functionality for statistical inference."""

from __future__ import annotations

import contextlib
import dataclasses
import typing as tp
from types import EllipsisType

import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx
import orbax.checkpoint as ocp
from jaxtyping import PyTree
from tqdm import tqdm

import everwillow._src.parameters as ewp
import everwillow._src.statelib as sl
from everwillow._src.statelib import K, V

# Auxiliary data handed to the optimiser next to the free parameters.
# Positional layout of the tuple:
#   [0] fixed_state -- the partition of the state that is held constant
#   [1] bounds      -- the transforms used to map between bounded/unbounded space
Args: tp.TypeAlias = tuple[sl.State[V], sl.State[ewp.TransformBase]]

if tp.TYPE_CHECKING:
    from everwillow._src.inference.callback import Callback


def _reconstruct_full_state(
    free_state: sl.State[V],
    *,
    args: Args,
) -> sl.State[V]:
    """Rebuild the complete parameter state from the free partition.

    Args:
        free_state: The partition of the state that the optimiser varies.
        args: ``(fixed_state, bounds)`` tuple as described by :data:`Args`.

    Returns:
        The recombined state after ``ewp.wrap`` has been applied with the
        bounds mapping, i.e. the state that is handed to the NLL.
    """
    frozen_part, transforms = args

    # Merge the two partitions first; the values are still unbounded here.
    unbounded_state = sl.combine_partitions(frozen_part, free_state)

    # Map back into bounded space so the NLL sees physical parameter values.
    return ewp.wrap(unbounded_state, transforms.mapping)


class FitResult(eqx.Module, tp.Generic[V]):
    """Result of a fit operation.

    Instances are created by the fitting routines in this module once the
    optimiser has returned; the fields below are filled from the solver's
    solution object.
    """

    params: sl.State[V]  #: Fitted parameter state.
    nll: jax.Array  #: Negative log-likelihood at the optimum.
    success: jax.Array  #: Whether the optimisation converged.
    solver_result: PyTree  #: Raw solver result.
@dataclasses.dataclass(frozen=True, slots=True) class PreparedNllFn: """Sentinel wrapper around a combined NLL produced by :func:`prepare`. This is a thin callable that delegates to the inner function. Its only purpose is to enable ``isinstance`` checks so that :func:`prepare` can reject nested calls. """ _fn: tp.Callable = dataclasses.field(repr=False) def __call__(self, params: PyTree, observation: PyTree) -> float: return self._fn(params, observation) def prepare( nlls: tp.Sequence[tp.Callable], states: tp.Sequence[sl.State], ) -> tuple[PreparedNllFn, sl.State]: """Combine multiple models for a joint fit. Each NLL function and its corresponding state are merged so that a single call to :func:`fit` optimises all parameters simultaneously. Shared parameters (aligned via :func:`~everwillow._src.statelib.transform.apply_transformations`) are handled automatically. Args: nlls: Sequence of NLL callables ``(params_pytree, observation) -> float``. states: Sequence of :class:`~everwillow._src.statelib.state.State` instances, one per NLL. Returns: ``(combined_nll, merged_state)`` ready to pass to :func:`fit`. Raises: TypeError: If any element of *nlls* is already a :class:`PreparedNllFn`, or any element of *states* is a merged state. ValueError: If *nlls* and *states* have different lengths, or are empty. 
""" if len(nlls) != len(states): msg = f"nlls and states must have the same length, got {len(nlls)} and {len(states)}" raise ValueError(msg) if len(nlls) == 0: msg = "nlls and states must not be empty" raise ValueError(msg) for i, nll in enumerate(nlls): if isinstance(nll, PreparedNllFn): msg = f"nlls[{i}] was already produced by prepare() — nested prepare() is not allowed" raise TypeError(msg) for i, state in enumerate(states): if not isinstance(state, sl.State): msg = f"states[{i}] is not a State instance" # type: ignore[unreachable] raise TypeError(msg) if state.is_merged: msg = f"states[{i}] is already a merged state — nested prepare() is not allowed" raise TypeError(msg) merged_state = sl.merge(*states) def _combined_nll(params: PyTree, observation: PyTree) -> float: total = 0.0 # params is a tuple of pytrees from merged_state.to_pytree() for nll_fn, p in zip(nlls, params, strict=True): total = total + nll_fn(p, observation) return total return PreparedNllFn(_combined_nll), merged_state def _minimize( wrapped_nll: tp.Callable[[sl.State[V], Args], float], solver: optx.AbstractMinimiser, y0: sl.State[V], args: Args, *, max_steps: int, **minimise_kwargs, ) -> optx.Solution: """Non-interactive minimization using optx.minimise.""" return optx.minimise( wrapped_nll, solver, y0=y0, args=args, max_steps=max_steps, **minimise_kwargs, ) class _ProgressUpdater: """Progress bar updater that can adjust total on completion.""" def __init__(self, pbar): self._pbar = pbar def update(self, step: int, nll_value: float) -> None: self._pbar.n = step self._pbar.set_postfix(NLL=f"{nll_value:.6f}") self._pbar.refresh() def finalize(self, final_step: int, nll_value: float) -> None: """Set total to actual steps taken and complete the bar.""" self._pbar.total = final_step self._pbar.n = final_step self._pbar.set_postfix(NLL=f"{nll_value:.6f}") self._pbar.refresh() @contextlib.contextmanager def _make_progress_context( enabled: bool, max_steps: int, ) -> tp.Generator[_ProgressUpdater | 
None]: """Create progress bar context manager. Yields a _ProgressUpdater when enabled, else None. """ if not enabled: yield None return pbar = tqdm(total=max_steps, desc="Minimizing", unit="step") try: yield _ProgressUpdater(pbar) finally: pbar.close() def _iminimize( wrapped_nll: tp.Callable[[sl.State[V], Args], float], solver: optx.AbstractMinimiser, y0: sl.State[V], args: Args, *, max_steps: int, progress: bool, checkpoint_manager: ocp.CheckpointManager | None, callbacks: tp.Iterable[Callback] | None, solver_options: dict[str, tp.Any] | None = None, ) -> optx.Solution: """Interactive minimization with step-by-step iteration and progress bar.""" # If we have a checkpoint_manager _and_ something is stored in its path, we'll # automatically load and start from the latest checkpointed iteration if checkpoint_manager is not None and checkpoint_manager.latest_step() is not None: iteration = tp.cast(int, checkpoint_manager.latest_step()) abstract_y0 = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, y0) y0 = tp.cast( sl.State[V], checkpoint_manager.restore( iteration, args=ocp.args.StandardRestore(abstract_y0), ), ) else: iteration = 0 # Convert y0 leaves to JAX arrays (required for solver.init which calls tree_full_like) y0 = jax.tree_util.tree_map(lambda x: jnp.asarray(x, dtype=jnp.float64), y0) # Wrap nll to return (f, aux) tuple and ensure outputs are arrays def fn_with_aux(y, fn_args): result = wrapped_nll(y, fn_args) return jnp.asarray(result), None # Infer output structure for solver initialization via eval_shape f_struct = jax.eval_shape(lambda: fn_with_aux(y0, args)[0]) aux_struct = None options = solver_options if solver_options is not None else {} tags: frozenset[object] = frozenset() # Initialize solver state state = solver.init(fn_with_aux, y0, args, options, f_struct, aux_struct, tags) y = y0 aux = None done, result = solver.terminate(fn_with_aux, y, args, options, state, tags) # JIT the hot path for performance step = 
eqx.filter_jit(eqx.Partial(solver.step, fn=fn_with_aux, args=args, options=options, tags=tags)) terminate = eqx.filter_jit(eqx.Partial(solver.terminate, fn=fn_with_aux, args=args, options=options, tags=tags)) with _make_progress_context(progress, max_steps) as updater: while not done and iteration < max_steps: # Call user callback with current state (user can access state.f_info.f for NLL) if callbacks is not None: for callback in callbacks: callback(iteration, y, state) # Perform one solver step y, state, aux = step(y=y, state=state) iteration += 1 # checkpoint if we have a manager if checkpoint_manager is not None: checkpoint_manager.save(iteration, args=ocp.args.StandardSave(y)) # Check termination done, result = terminate(y=y, state=state) # Update progress bar with current NLL if updater is not None: current_nll = float(state.f_info.f) updater.update(iteration, current_nll) # Final progress bar update - set total to actual steps taken if updater is not None: updater.finalize(iteration, float(state.f_info.f)) # Postprocess y_final, aux_final, stats = solver.postprocess(fn_with_aux, y, aux, args, options, state, tags, result) # Include iteration info in stats stats["num_steps"] = jnp.asarray(iteration) stats["max_steps"] = max_steps return optx.Solution( value=y_final, result=result, aux=aux_final, stats=stats, state=state, ) def _fit( nll_fn: tp.Callable[[PyTree[V], PyTree], float], params: sl.State[V], observation: PyTree, *, fixed: sl.State[V | EllipsisType] | None = None, bounds: sl.State[ewp.TransformBase] | None = None, solver: optx.AbstractMinimiser | None = None, interactive: bool = False, max_steps: int = 256, progress: bool = True, checkpoint_manager: ocp.CheckpointManager | None = None, callbacks: tp.Iterable[Callback] | None = None, solver_options: dict[str, tp.Any] | None = None, **minimise_kwargs, ) -> FitResult[V]: """Internal fit implementation that handles both interactive and non-interactive modes. 
This function performs shared setup (validation, partitioning, transforms) then dispatches to _minimize() or _iminimize() based on the `interactive` flag. """ # Validate inputs if not isinstance(params, sl.State): msg = "params must be a State" # type: ignore[unreachable] raise TypeError(msg) # Normalize fixed and bounds inputs if fixed is None: fixed = sl.State.from_pytree({}) if not isinstance(fixed, sl.State): msg = "fixed must be a State or None" # type: ignore[unreachable] raise TypeError(msg) if bounds is None: bounds = sl.State.from_pytree({}) if not isinstance(bounds, sl.State): msg = "bounds must be a State or None" # type: ignore[unreachable] raise TypeError(msg) # Set fixed values updated_params = sl.update(params, updates=fixed) # Apply bounds transformations (convert to unbounded space) param_state_transformed = ewp.unwrap(updated_params, transform_mapping=bounds) # Partition state into fixed and free components def predicate(key: K, value: V) -> bool: del value # unused return key in fixed fixed_state, free_state = sl.partition( param_state_transformed, predicate=predicate, ) # Prepare args for reconstructing full state args: Args = (fixed_state, bounds) # Wrap nll to only take free parameters def wrapped_nll(new_state, fn_args): full_state = _reconstruct_full_state(new_state, args=fn_args) return nll_fn(full_state.to_pytree(), observation) # Set up solver if solver is None: solver = optx.BFGS(rtol=1e-5, atol=1e-5) # Dispatch to appropriate backend if interactive: solution = _iminimize( wrapped_nll, solver, free_state, args, max_steps=max_steps, progress=progress, callbacks=callbacks, checkpoint_manager=checkpoint_manager, solver_options=solver_options, ) else: solution = _minimize( wrapped_nll, solver, free_state, args, max_steps=max_steps, **minimise_kwargs, ) # Reconstruct full fitted state fitted_state = _reconstruct_full_state(solution.value, args=args) # Return result return FitResult( params=fitted_state, nll=wrapped_nll(solution.value, args), 
success=jax.numpy.asarray(solution.result == optx.RESULTS.successful), solver_result=solution, )
def fit(
    nll_fn: tp.Callable[[PyTree[V], PyTree], float],
    params: sl.State[V],
    observation: PyTree,
    *,
    fixed: sl.State[V | EllipsisType] | None = None,
    bounds: sl.State[ewp.TransformBase] | None = None,
    solver: optx.AbstractMinimiser | None = None,
    max_steps: int = 256,
    **minimise_kwargs,
) -> FitResult[V]:
    """Perform an unconditional maximum-likelihood fit.

    The negative log-likelihood (NLL) provided via ``nll_fn`` is minimised
    with respect to all parameters except those explicitly marked as fixed.
    The parameters must be supplied as a
    :class:`~everwillow._src.statelib.state.State` so that subsets of the
    state can be frozen using :func:`everwillow._src.statelib.state.partition`.
    Parameter bounds are supported through automatic transformation to
    unbounded space.

    Args:
        nll_fn: Callable returning the scalar NLL. It must accept the
            parameter pytree as its first argument and observation data as
            its second.
        params: Initial parameter values as a ``State`` (e.g. built with
            ``State.from_pytree`` from a mapping or nested containers).
        observation: Observed data passed to ``nll_fn``. Can be any pytree
            structure (dict, array, nested containers, etc.).
        fixed: Optional state of canonicalized keys to fixed values for
            identifying parameters that should remain unchanged during the
            fit.
        bounds: Optional state of
            :class:`~everwillow._src.parameters.transforms.TransformBase`
            instances. When provided, parameters are unwrapped via the
            transform's ``unwrap`` method prior to optimisation and wrapped
            back afterwards.
        solver: :class:`optimistix.AbstractMinimiser` instance to use.
            Defaults to :class:`optimistix.BFGS`.
        max_steps: Maximum number of optimization steps. Defaults to 256.
        **minimise_kwargs: Additional keyword arguments forwarded to
            :func:`optimistix.minimise`.

    Returns:
        :class:`FitResult` containing the fitted parameters and diagnostics.

    Raises:
        TypeError: If ``params`` is not a ``State``, or if ``fixed`` or
            ``bounds`` is neither ``None`` nor a ``State``.

    Examples:
        >>> import everwillow as ew
        >>> import everwillow._src.statelib as sl

        >>> # Basic usage
        >>> def my_nll(params, observation):
        ...     return (params["mu"] - observation["target"])**2
        >>> initial_params = sl.State.from_pytree({"mu": 0.0})
        >>> observed = {"target": 2.0}
        >>> result = ew.fit(my_nll, initial_params, observed)
        >>> result.params["mu"]  # Should be close to 2.0

        >>> # Fix 'sigma' while fitting mu
        >>> def my_nll(params, observation):
        ...     return (params["mu"] - observation["mu_target"]) ** 2 + (
        ...         params["sigma"] - observation["sigma_target"]
        ...     ) ** 2
        >>> initial_params = sl.State.from_pytree({"mu": 0.0, "sigma": 0.5})
        >>> observed = {"mu_target": 2.0, "sigma_target": 1.0}
        >>> fixed = sl.State.from_pytree({"sigma": ...})
        >>> result = ew.fit(my_nll, initial_params, observed, fixed=fixed)
        >>> result.params["sigma"]  # Remains fixed
        0.5

        >>> # With parameter bounds
        >>> from everwillow._src.parameters.transforms import MinuitTransform
        >>> bounds = sl.State.from_pytree({"mu": MinuitTransform(lower=0.0, upper=5.0)})
        >>> result = ew.fit(my_nll, initial_params, observed, bounds=bounds)
        >>> 0.0 <= result.params["mu"] <= 5.0  # Respects bounds
        True
    """
    return _fit(
        nll_fn,
        params,
        observation,
        fixed=fixed,
        bounds=bounds,
        solver=solver,
        interactive=False,
        max_steps=max_steps,
        **minimise_kwargs,
    )
def ifit(
    nll_fn: tp.Callable[[PyTree[V], PyTree], float],
    params: sl.State[V],
    observation: PyTree,
    *,
    fixed: sl.State[V | EllipsisType] | None = None,
    bounds: sl.State[ewp.TransformBase] | None = None,
    solver: optx.AbstractMinimiser | None = None,
    max_steps: int = 256,
    progress: bool = True,
    checkpoint_manager: ocp.CheckpointManager | None = None,
    callbacks: tp.Iterable[Callback] | None = None,
    solver_options: dict[str, tp.Any] | None = None,
) -> FitResult[V]:
    """Perform an interactive maximum-likelihood fit with progress bar and callbacks.

    This function is similar to :func:`fit` but provides interactive
    features: a tqdm progress bar displayed during optimization, optional
    callbacks invoked at each iteration, and optional checkpointing.

    Note:
        This function is not JIT-compatible due to Python side effects in the
        iteration loop. Use :func:`fit` if you need JIT compilation.

    Args:
        nll_fn: Callable returning the scalar NLL. It must accept the
            parameter pytree as its first argument and observation data as
            its second.
        params: Initial parameter values organised as a state.
        observation: Observed data passed to ``nll_fn``. Can be any pytree
            structure (dict, array, nested containers, etc.).
        fixed: Optional state of canonicalized keys to fixed values for
            identifying parameters that should remain unchanged during the
            fit.
        bounds: Optional state of
            :class:`~everwillow._src.parameters.transforms.TransformBase`
            instances for parameter bounds.
        solver: :class:`optimistix.AbstractMinimiser` instance to use.
            Defaults to :class:`optimistix.BFGS`.
        max_steps: Maximum number of optimization steps. Defaults to 256.
        progress: Whether to display a progress bar. Defaults to True.
        checkpoint_manager: Optional ``ocp.CheckpointManager`` that
            checkpoints during interactive fitting. If checkpoints already
            exist under the provided path, the fit automatically continues
            from the last checkpointed step.
        callbacks: Optional iterable of callables, each invoked every
            iteration with signature
            ``(iteration: int, y: State, state: SolverState) -> None``.
            This matches the pattern used by ``solver.step`` and
            ``solver.terminate``. The NLL value can be accessed via
            ``state.f_info.f``.
        solver_options: Optional dict of solver-specific options passed to
            ``solver.init``.

    Returns:
        :class:`FitResult` containing the fitted parameters and diagnostics.

    Raises:
        TypeError: If ``params`` is not a ``State``, or if ``fixed`` or
            ``bounds`` is neither ``None`` nor a ``State``.

    Examples:
        >>> import everwillow as ew
        >>> import everwillow._src.statelib as sl

        >>> # Interactive fit with progress bar
        >>> def my_nll(params, observation):
        ...     return (params["mu"] - observation["target"])**2
        >>> initial_params = sl.State.from_pytree({"mu": 0.0})
        >>> observed = {"target": 2.0}
        >>> result = ew.ifit(my_nll, initial_params, observed)

        >>> # With custom callback to record history
        >>> history = ew.inference.HistoryCallback()
        >>> result = ew.ifit(my_nll, initial_params, observed, callbacks=[history])

        >>> # Disable progress bar
        >>> result = ew.ifit(my_nll, initial_params, observed, progress=False)

    Example with Checkpointing:
        >>> import orbax.checkpoint as ocp
        >>> mngr = ocp.CheckpointManager('/tmp/my_fit')
        >>> result = ew.ifit(my_nll, initial_params, observed, max_steps=10, progress=False, checkpoint_manager=mngr)
        >>> mngr.wait_until_finished()  # checkpointing is async, so let's make sure everything is checkpointed
        >>> # run for another 10 steps (increase max_steps to 20); recover from latest checkpoint by reusing the `mngr`:
        >>> result = ew.ifit(my_nll, initial_params, observed, max_steps=20, progress=False, checkpoint_manager=mngr)
        >>> mngr.wait_until_finished()
    """
    return _fit(
        nll_fn,
        params,
        observation,
        fixed=fixed,
        bounds=bounds,
        solver=solver,
        interactive=True,
        max_steps=max_steps,
        progress=progress,
        checkpoint_manager=checkpoint_manager,
        callbacks=callbacks,
        solver_options=solver_options,
    )