Fitting

Core fitting functionality for statistical inference.

class everwillow.fitting.FitResult(params, nll, success, solver_result)

Bases: Module, Generic[V]

Result of a fit operation, as returned by fit() and ifit().

everwillow.fitting.fit(nll_fn, params, observation, *, fixed=None, bounds=None, solver=None, max_steps=256, **minimise_kwargs)

Perform an unconditional maximum-likelihood fit.

The negative log-likelihood (NLL) returned by nll_fn is minimised with respect to all parameters except those explicitly marked as fixed. Internally, the parameter pytree is converted into a State so that subsets of the state can be frozen using everwillow._src.statelib.state.partition().
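The freeze-and-recombine idea can be sketched in plain Python, assuming flat dict parameters. partition, combine, minimise and fit_sketch are hypothetical helpers, not everwillow._src.statelib.state.partition() or optimistix; the minimiser is plain gradient descent with finite-difference gradients rather than BFGS, and treating ... as "keep the initial value" follows the fixed example in the Examples section.

```python
from typing import Callable, Dict, Optional, Tuple

Params = Dict[str, float]


def partition(params: Params, fixed: Dict[str, object]) -> Tuple[Params, Params]:
    """Hypothetical helper: split params into (free, frozen) dicts.

    A value of ... (Ellipsis) in `fixed` freezes that parameter at its initial value.
    """
    free = {k: v for k, v in params.items() if k not in fixed}
    frozen = {k: (params[k] if v is ... else v) for k, v in fixed.items()}
    return free, frozen


def combine(free: Params, frozen: Params) -> Params:
    """Hypothetical helper: merge free and frozen parameters into one dict."""
    return {**free, **frozen}


def minimise(fn: Callable[[Params], float], y0: Params,
             lr: float = 0.1, steps: int = 200, eps: float = 1e-6) -> Params:
    """Plain gradient descent with central-difference gradients (not BFGS)."""
    y = dict(y0)
    for _ in range(steps):
        grad = {
            k: (fn({**y, k: y[k] + eps}) - fn({**y, k: y[k] - eps})) / (2 * eps)
            for k in y
        }
        y = {k: y[k] - lr * grad[k] for k in y}
    return y


def fit_sketch(nll_fn, params: Params, observation,
               fixed: Optional[Dict[str, object]] = None) -> Params:
    free, frozen = partition(params, fixed or {})

    def objective(y: Params) -> float:
        # The minimiser only sees the free parameters; frozen ones are
        # re-inserted before every NLL evaluation.
        return nll_fn(combine(y, frozen), observation)

    return combine(minimise(objective, free), frozen)


def my_nll(p, obs):
    return (p["mu"] - obs["mu_target"]) ** 2 + (p["sigma"] - obs["sigma_target"]) ** 2


result = fit_sketch(my_nll, {"mu": 0.0, "sigma": 0.5},
                    {"mu_target": 2.0, "sigma_target": 1.0}, fixed={"sigma": ...})
print(round(result["mu"], 3), result["sigma"])  # prints: 2.0 0.5
```

Because the minimiser never sees the frozen entries, they cannot drift, while the NLL still receives the full parameter set on every call.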

Parameter bounds (via bounds) are supported through automatic transformation to an unbounded space.
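As an illustration, a common choice for a parameter with both a lower and an upper bound is the MINUIT-style sine transform sketched below. Whether MinuitTransform uses exactly these formulas is an assumption; unwrap and wrap are plain hypothetical functions here, named after the transform methods mentioned under bounds.

```python
import math


def unwrap(x: float, lower: float, upper: float) -> float:
    """Map a bounded value x in [lower, upper] to an unbounded internal value."""
    return math.asin(2.0 * (x - lower) / (upper - lower) - 1.0)


def wrap(z: float, lower: float, upper: float) -> float:
    """Map any real internal value z back into [lower, upper]."""
    return lower + (upper - lower) * (math.sin(z) + 1.0) / 2.0


lower, upper = 0.0, 5.0
z = unwrap(2.0, lower, upper)
print(round(wrap(z, lower, upper), 12))  # prints: 2.0 (round trip)
# Any internal value, however large, maps back inside the bounds.
print(all(lower <= wrap(t, lower, upper) <= upper
          for t in (-1e3, -1.0, 0.0, 7.5, 1e3)))  # prints: True
```

One known trade-off of this transform: the derivative of wrap, (upper - lower) / 2 * cos(z), vanishes at the bounds, so a parameter whose optimum lies on a bound can slow the optimiser down.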

Parameters:
  • nll_fn (Callable[[PyTree[V], PyTree], float]) – Callable returning the scalar NLL. It must accept the parameter pytree as its first argument and observation data as its second.

  • params (State[V]) – Initial parameter values organised as a state (e.g. mapping or nested containers).

  • observation (PyTree) – Observed data passed to nll_fn. Can be any pytree structure (dict, array, nested containers, etc.).

  • fixed (State[V | EllipsisType] | None) – Optional state mapping canonicalized parameter keys to fixed values; these parameters remain unchanged during the fit. A value of ... (Ellipsis) holds the parameter at its initial value.

  • bounds (State[TransformBase] | None) – Optional state of TransformBase instances. When provided, parameters are unwrapped via the transform’s unwrap method prior to optimisation and wrapped back afterwards.

  • solver (AbstractMinimiser | None) – optimistix.AbstractMinimiser instance to use. Defaults to optimistix.BFGS.

  • max_steps (int) – Maximum number of optimization steps. Defaults to 256.

  • **minimise_kwargs – Additional keyword arguments forwarded to optimistix.minimise().

Returns:

FitResult containing the fitted parameters and diagnostics.

Return type:

FitResult[V]

Examples

>>> import everwillow as ew
>>> import everwillow._src.statelib as sl
>>> # Basic usage
>>> def my_nll(params, observation):
...     return (params["mu"] - observation["target"])**2
>>> initial_params = sl.State.from_pytree({"mu": 0.0})
>>> observed = {"target": 2.0}
>>> result = ew.fit(my_nll, initial_params, observed)
>>> abs(float(result.params["mu"]) - 2.0) < 1e-3  # Should be close to 2.0
True
>>> # Fix 'sigma' while fitting mu
>>> def my_nll(params, observation):
...     return (params["mu"] - observation["mu_target"]) ** 2 + (
...         params["sigma"] - observation["sigma_target"]
...     ) ** 2
>>> initial_params = sl.State.from_pytree({"mu": 0.0, "sigma": 0.5})
>>> observed = {"mu_target": 2.0, "sigma_target": 1.0}
>>> fixed = sl.State.from_pytree({"sigma": ...})
>>> result = ew.fit(my_nll, initial_params, observed, fixed=fixed)
>>> float(result.params["sigma"])  # Remains fixed at its initial value
0.5
>>> # With parameter bounds
>>> from everwillow._src.parameters.transforms import MinuitTransform
>>> bounds = sl.State.from_pytree({"mu": MinuitTransform(lower=0.0, upper=5.0)})
>>> result = ew.fit(my_nll, initial_params, observed, bounds=bounds)
>>> 0.0 <= float(result.params["mu"]) <= 5.0  # Respects bounds
True

everwillow.fitting.ifit(nll_fn, params, observation, *, fixed=None, bounds=None, solver=None, max_steps=256, progress=True, checkpoint_manager=None, callbacks=None, solver_options=None)

Perform an interactive maximum-likelihood fit with progress bar and callbacks.

This function is similar to fit() but adds interactive features: a rich progress bar displayed during optimization, optional callbacks invoked at each iteration, and optional checkpointing with automatic resumption.

Note

This function is not JIT-compatible due to Python side effects in the iteration loop. Use fit() if you need JIT compilation.
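Why the Python loop rules out JIT can be seen in a minimal sketch: callbacks and progress output are ordinary Python side effects executed on every iteration, whereas under jax.jit such side effects run only once, while the function is being traced. interactive_loop and gd_step are hypothetical stand-ins for ifit's internals, and the state object only imitates the state.f_info.f attribute described under callbacks.

```python
from types import SimpleNamespace
from typing import Callable, Iterable, List, Tuple


def interactive_loop(step: Callable[[float], Tuple[float, SimpleNamespace]],
                     y0: float, max_steps: int,
                     callbacks: Iterable[Callable] = (),
                     progress: bool = True) -> float:
    """Hypothetical driver: run `step` in a Python loop with per-iteration hooks."""
    callbacks = list(callbacks)  # allow generators to be reused every iteration
    y = y0
    for iteration in range(max_steps):
        y, state = step(y)
        for cb in callbacks:
            cb(iteration, y, state)  # arbitrary Python side effect, every iteration
        if progress:
            print(f"step {iteration + 1}/{max_steps}  nll={state.f_info.f:.4g}")
    return y


def gd_step(mu: float, lr: float = 0.25) -> Tuple[float, SimpleNamespace]:
    """One gradient-descent step on (mu - 2)**2; the state mimics state.f_info.f."""
    new_mu = mu - lr * 2.0 * (mu - 2.0)
    state = SimpleNamespace(f_info=SimpleNamespace(f=(new_mu - 2.0) ** 2))
    return new_mu, state


history: List[float] = []
mu = interactive_loop(gd_step, 0.0, max_steps=5,
                      callbacks=[lambda i, y, s: history.append(s.f_info.f)],
                      progress=False)
print(mu, history)  # prints: 1.9375 [1.0, 0.25, 0.0625, 0.015625, 0.00390625]
```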

Parameters:
  • nll_fn (Callable[[PyTree[V], PyTree], float]) – Callable returning the scalar NLL. It must accept the parameter pytree as its first argument and observation data as its second.

  • params (State[V]) – Initial parameter values organised as a state.

  • observation (PyTree) – Observed data passed to nll_fn. Can be any pytree structure (dict, array, nested containers, etc.).

  • fixed (State[V | EllipsisType] | None) – Optional state mapping canonicalized parameter keys to fixed values; these parameters remain unchanged during the fit. A value of ... (Ellipsis) holds the parameter at its initial value.

  • bounds (State[TransformBase] | None) – Optional state of TransformBase instances for parameter bounds.

  • solver (AbstractMinimiser | None) – optimistix.AbstractMinimiser instance to use. Defaults to optimistix.BFGS.

  • max_steps (int) – Maximum number of optimization steps. Defaults to 256.

  • progress (bool) – Whether to display a rich progress bar. Defaults to True.

  • checkpoint_manager (orbax.checkpoint.CheckpointManager | None) – Optional Orbax CheckpointManager used to save checkpoints during interactive fitting. If checkpoints already exist under the manager's directory, the fit automatically resumes from the last checkpointed step.

  • callbacks (Iterable[Callback] | None) – Optional iterable of callbacks, each called once per iteration with the signature (iteration: int, y: State, state: SolverState) -> None. This matches the pattern used by solver.step and solver.terminate. The current NLL value is available as state.f_info.f.

  • solver_options (dict[str, Any] | None) – Optional dict of solver-specific options passed to solver.init.
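The resume-from-last-step behaviour described for checkpoint_manager can be sketched with a hypothetical in-memory store whose method names (latest_step, save, restore) loosely echo Orbax's CheckpointManager; it is not Orbax's API, and what ifit actually stores (parameters, solver state) is not specified here. The point it illustrates, taken from the checkpointing example, is that max_steps counts total steps, so resuming with a larger max_steps runs only the remaining ones.

```python
from typing import Callable, Dict, Optional


class InMemoryCheckpoints:
    """Hypothetical in-memory checkpoint store (not Orbax's CheckpointManager API)."""

    def __init__(self) -> None:
        self._store: Dict[int, float] = {}

    def latest_step(self) -> Optional[int]:
        return max(self._store) if self._store else None

    def save(self, step: int, y: float) -> None:
        self._store[step] = y

    def restore(self, step: int) -> float:
        return self._store[step]


def resumable_fit(step_fn: Callable[[float], float], y0: float, max_steps: int,
                  ckpt: InMemoryCheckpoints) -> float:
    """Run step_fn until `max_steps` total steps, resuming after the latest checkpoint."""
    latest = ckpt.latest_step()
    if latest is None:
        y, first = y0, 0  # fresh run: start from the initial parameters
    else:
        y, first = ckpt.restore(latest), latest + 1  # resume after the saved step
    for step in range(first, max_steps):
        y = step_fn(y)
        ckpt.save(step, y)  # record the value reached after this step
    return y


def halve_gap(mu: float) -> float:
    """Toy update that halves the distance between mu and 2.0."""
    return mu - 0.5 * (mu - 2.0)


ckpt = InMemoryCheckpoints()
a = resumable_fit(halve_gap, 0.0, max_steps=10, ckpt=ckpt)  # runs steps 0..9
b = resumable_fit(halve_gap, 0.0, max_steps=20, ckpt=ckpt)  # resumes: steps 10..19
print(ckpt.latest_step())  # prints: 19
```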

Returns:

FitResult containing the fitted parameters and diagnostics.

Return type:

FitResult[V]

Examples

>>> import everwillow as ew
>>> import everwillow._src.statelib as sl
>>> # Interactive fit with progress bar
>>> def my_nll(params, observation):
...     return (params["mu"] - observation["target"])**2
>>> initial_params = sl.State.from_pytree({"mu": 0.0})
>>> observed = {"target": 2.0}
>>> result = ew.ifit(my_nll, initial_params, observed)
>>> # With a custom callback that records the history
>>> history = ew.inference.HistoryCallback()
>>> result = ew.ifit(my_nll, initial_params, observed, callbacks=[history])
>>> # Disable the progress bar
>>> result = ew.ifit(my_nll, initial_params, observed, progress=False)
>>> # With checkpointing
>>> import orbax.checkpoint as ocp
>>> mngr = ocp.CheckpointManager('/tmp/my_fit')
>>> result = ew.ifit(
...     my_nll, initial_params, observed,
...     max_steps=10, progress=False, checkpoint_manager=mngr,
... )
>>> mngr.wait_until_finished()  # checkpointing is asynchronous; wait until all writes complete
>>> # Resume from the latest checkpoint by reusing mngr; raising max_steps to 20 runs up to 10 more steps
>>> result = ew.ifit(
...     my_nll, initial_params, observed,
...     max_steps=20, progress=False, checkpoint_manager=mngr,
... )
>>> mngr.wait_until_finished()