Fitting
Core fitting functionality for statistical inference.
- class everwillow.fitting.FitResult(params, nll, success, solver_result)
Bases: Module, Generic[V]
Result of a fit operation.
- everwillow.fitting.fit(nll_fn, params, observation, *, fixed=None, bounds=None, solver=None, max_steps=256, **minimise_kwargs)
Perform an unconditional maximum-likelihood fit.
The negative log-likelihood (NLL) provided via nll_fn is minimised with respect to all parameters except those explicitly marked as fixed. Internally, the parameter pytree is converted into a State so that subsets of the state can be frozen using everwillow._src.statelib.state.partition(). Parameter bounds are supported through automatic transformation to unbounded space.
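To make "transformation to unbounded space" concrete, here is a minimal sketch assuming a Minuit-style sine mapping for a parameter with both a lower and an upper bound. The class SinBoundTransform is hypothetical: it only mirrors the unwrap/wrap naming used for TransformBase on this page and is not everwillow's MinuitTransform, whose exact formula may differ.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class SinBoundTransform:
    """Hypothetical Minuit-style transform between [lower, upper] and the real line."""

    lower: float
    upper: float

    def unwrap(self, value: float) -> float:
        # Bounded -> unbounded: x_int = arcsin(2 * (x - lower) / (upper - lower) - 1)
        scaled = 2.0 * (value - self.lower) / (self.upper - self.lower) - 1.0
        return math.asin(scaled)

    def wrap(self, internal: float) -> float:
        # Unbounded -> bounded: x = lower + (upper - lower) * (sin(x_int) + 1) / 2
        return self.lower + (self.upper - self.lower) * (math.sin(internal) + 1.0) / 2.0


t = SinBoundTransform(lower=0.0, upper=5.0)
internal = t.unwrap(2.0)  # an optimiser would work on this unconstrained value
print(t.wrap(internal))   # round-trips to (approximately) 2.0
print(t.wrap(100.0))      # any internal value maps back into [0, 5]
```

Because sin is bounded, every value the optimiser proposes in internal space maps back inside the bounds, which is why the fitted result respects them without a constrained solver.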
- Parameters:
nll_fn (Callable[[PyTree[~V], PyTree], float]) – Callable returning the scalar NLL. It must accept the parameter pytree as its first argument and observation data as its second.
params (State[V]) – Initial parameter values organised as a state (e.g. mapping or nested containers).
observation (PyTree) – Observed data passed to nll_fn. Can be any pytree structure (dict, array, nested containers, etc.).
fixed (State[V | EllipsisType] | None) – Optional state mapping canonicalized keys to fixed values; it identifies the parameters that remain unchanged during the fit. An Ellipsis (...) holds a parameter at its initial value.
bounds (State[TransformBase] | None) – Optional state of TransformBase instances. When provided, parameters are unwrapped via the transform's unwrap method prior to optimisation and wrapped back afterwards.
solver (AbstractMinimiser | None) – optimistix.AbstractMinimiser instance to use. Defaults to optimistix.BFGS.
max_steps (int) – Maximum number of optimization steps. Defaults to 256.
**minimise_kwargs – Additional keyword arguments forwarded to optimistix.minimise().
- Returns:
FitResult containing the fitted parameters and diagnostics.
- Return type:
FitResult[V]
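To illustrate how the fixed argument separates parameters, here is a toy sketch on flat dicts rather than State objects. The helpers split_fixed and objective_over_free are hypothetical; the sketch assumes, as the State[V | EllipsisType] type suggests, that an Ellipsis keeps the initial value while any other entry supplies the value to hold. A crude finite-difference gradient descent stands in for optimistix.BFGS.

```python
def split_fixed(params: dict, fixed: dict) -> tuple[dict, dict]:
    """Hypothetical helper: separate free parameters from held (fixed) ones.

    An entry of ``...`` keeps the parameter at its initial value; any other
    entry is assumed to be the value to hold it at.
    """
    free = {k: v for k, v in params.items() if k not in fixed}
    held = {k: (params[k] if v is ... else v) for k, v in fixed.items()}
    return free, held


def objective_over_free(nll_fn, held: dict, observation):
    """Hypothetical helper: expose the NLL as a function of the free parameters only."""

    def wrapped(free: dict) -> float:
        return nll_fn({**free, **held}, observation)

    return wrapped


def my_nll(params, observation):
    return (params["mu"] - observation["mu_target"]) ** 2 + (
        params["sigma"] - observation["sigma_target"]
    ) ** 2


free, held = split_fixed({"mu": 0.0, "sigma": 0.5}, {"sigma": ...})
f = objective_over_free(my_nll, held, {"mu_target": 2.0, "sigma_target": 1.0})

# Crude gradient descent over the single free parameter (stand-in for a real solver).
mu, h = free["mu"], 1e-6
for _ in range(200):
    grad = (f({"mu": mu + h}) - f({"mu": mu - h})) / (2 * h)
    mu -= 0.1 * grad
fitted = {"mu": mu, **held}  # sigma stays at 0.5 while mu approaches 2.0
```

The optimiser never sees sigma at all, so it cannot move; the held values are only merged back in when the NLL is evaluated and when the result is assembled.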
Examples
>>> import everwillow as ew
>>> import everwillow._src.statelib as sl
>>> # Basic usage
>>> def my_nll(params, observation):
...     return (params["mu"] - observation["target"]) ** 2
>>> initial_params = sl.State.from_pytree({"mu": 0.0})
>>> observed = {"target": 2.0}
>>> result = ew.fit(my_nll, initial_params, observed)
>>> result.params["mu"]  # Should be close to 2.0
>>> # Fix 'sigma' while fitting mu
>>> def my_nll(params, observation):
...     return (params["mu"] - observation["mu_target"]) ** 2 + (
...         params["sigma"] - observation["sigma_target"]
...     ) ** 2
>>> initial_params = sl.State.from_pytree({"mu": 0.0, "sigma": 0.5})
>>> observed = {"mu_target": 2.0, "sigma_target": 1.0}
>>> fixed = sl.State.from_pytree({"sigma": ...})
>>> result = ew.fit(my_nll, initial_params, observed, fixed=fixed)
>>> result.params["sigma"]  # Remains fixed
0.5
>>> # With parameter bounds
>>> from everwillow._src.parameters.transforms import MinuitTransform
>>> bounds = sl.State.from_pytree({"mu": MinuitTransform(lower=0.0, upper=5.0)})
>>> result = ew.fit(my_nll, initial_params, observed, bounds=bounds)
>>> 0.0 <= result.params["mu"] <= 5.0  # Respects bounds
True
- everwillow.fitting.ifit(nll_fn, params, observation, *, fixed=None, bounds=None, solver=None, max_steps=256, progress=True, checkpoint_manager=None, callbacks=None, solver_options=None)
Perform an interactive maximum-likelihood fit with progress bar and callbacks.
This function is similar to fit() but adds interactive features: a rich progress bar displayed during optimization, optional callbacks invoked at each iteration, and optional checkpointing.
Note
This function is not JIT-compatible because of Python side effects in the iteration loop. Use fit() if you need JIT compilation.
- Parameters:
nll_fn (tp.Callable[[PyTree[V], PyTree], float]) – Callable returning the scalar NLL. It must accept the parameter pytree as its first argument and observation data as its second.
params (sl.State[V]) – Initial parameter values organised as a state.
observation (PyTree) – Observed data passed to nll_fn. Can be any pytree structure (dict, array, nested containers, etc.).
fixed (sl.State[V | EllipsisType] | None) – Optional state mapping canonicalized keys to fixed values; it identifies the parameters that remain unchanged during the fit. An Ellipsis (...) holds a parameter at its initial value.
bounds (sl.State[ewp.TransformBase] | None) – Optional state of TransformBase instances for parameter bounds.
solver (optx.AbstractMinimiser | None) – optimistix.AbstractMinimiser instance to use. Defaults to optimistix.BFGS.
max_steps (int) – Maximum number of optimization steps. Defaults to 256.
progress (bool) – Whether to display a rich progress bar. Defaults to True.
checkpoint_manager (ocp.CheckpointManager | None) – Optional ocp.CheckpointManager used to save checkpoints during interactive fitting. If checkpoints already exist under the manager's directory, the fit automatically continues from the last checkpointed step.
callbacks (tp.Iterable[Callback] | None) – Optional callback(s) invoked at each iteration with signature (iteration: int, y: State, state: SolverState) -> None. This matches the argument pattern used by solver.step and solver.terminate. The current NLL value is available as state.f_info.f.
solver_options (dict[str, tp.Any] | None) – Optional dict of solver-specific options passed to solver.init.
- Returns:
FitResult containing the fitted parameters and diagnostics.
- Return type:
FitResult[V]
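The sketch below is a toy, pure-Python stand-in for the kind of loop ifit describes, not everwillow's implementation: each iteration takes one optimiser step and then calls every callback with (iteration, y, state), where the NLL is read as state.f_info.f. The names FInfo, ToyState, RecordHistory and toy_interactive_loop are hypothetical (RecordHistory is only loosely analogous to ew.inference.HistoryCallback, whose real interface may differ), finite-difference gradient descent replaces the optimistix solver, and the progress bar is omitted.

```python
from dataclasses import dataclass, field
from typing import Callable, Iterable


@dataclass
class FInfo:
    f: float  # current objective (NLL) value


@dataclass
class ToyState:
    f_info: FInfo  # lets callbacks read the NLL as state.f_info.f


@dataclass
class RecordHistory:
    """Hypothetical callback with the (iteration, y, state) -> None signature."""

    nll: list = field(default_factory=list)

    def __call__(self, iteration: int, y: dict, state: ToyState) -> None:
        self.nll.append((iteration, state.f_info.f))


def toy_interactive_loop(
    nll_fn: Callable[[dict], float],
    y: dict,
    max_steps: int,
    callbacks: Iterable[Callable[[int, dict, ToyState], None]] = (),
    lr: float = 0.1,
    h: float = 1e-6,
) -> dict:
    callbacks = list(callbacks)
    for iteration in range(max_steps):
        # One gradient-descent step using central finite differences.
        grad = {
            k: (nll_fn({**y, k: v + h}) - nll_fn({**y, k: v - h})) / (2 * h)
            for k, v in y.items()
        }
        y = {k: v - lr * grad[k] for k, v in y.items()}
        state = ToyState(FInfo(nll_fn(y)))
        # Plain Python side effects on every iteration, as the note on ifit describes.
        for cb in callbacks:
            cb(iteration, y, state)
    return y


history = RecordHistory()
result = toy_interactive_loop(
    lambda p: (p["mu"] - 2.0) ** 2, {"mu": 0.0}, max_steps=50, callbacks=[history]
)
```

Because the callbacks are ordinary Python calls made between solver steps, they can print, log, or append to lists freely; that flexibility is exactly what a traced, compiled loop cannot offer.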
Examples
>>> import everwillow as ew
>>> import everwillow._src.statelib as sl
>>> # Interactive fit with progress bar
>>> def my_nll(params, observation):
...     return (params["mu"] - observation["target"]) ** 2
>>> initial_params = sl.State.from_pytree({"mu": 0.0})
>>> observed = {"target": 2.0}
>>> result = ew.ifit(my_nll, initial_params, observed)
>>> # With custom callback to record history
>>> history = ew.inference.HistoryCallback()
>>> result = ew.ifit(my_nll, initial_params, observed, callbacks=[history])
>>> # Disable progress bar
>>> result = ew.ifit(my_nll, initial_params, observed, progress=False)
- Example with Checkpointing:
>>> import orbax.checkpoint as ocp
>>> mngr = ocp.CheckpointManager('/tmp/my_fit')
>>> result = ew.ifit(my_nll, initial_params, observed, max_steps=10, progress=False, checkpoint_manager=mngr)
>>> mngr.wait_until_finished()  # checkpointing is asynchronous; wait until all checkpoints are written
>>> # Run for another 10 steps (max_steps=20), resuming from the latest checkpoint by reusing `mngr`:
>>> result = ew.ifit(my_nll, initial_params, observed, max_steps=20, progress=False, checkpoint_manager=mngr)
>>> mngr.wait_until_finished()
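The resume behaviour in the checkpointing example can be sketched with a toy in-memory store. ToyCheckpointManager and toy_resumable_fit are hypothetical and are not orbax's API; the sketch only assumes that the fit looks up the latest saved step, restores the parameters from it, and runs the remaining steps up to max_steps.

```python
from typing import Callable, Optional


class ToyCheckpointManager:
    """Hypothetical in-memory stand-in for a checkpoint manager (not orbax)."""

    def __init__(self) -> None:
        self._store: dict[int, dict] = {}

    def save(self, step: int, y: dict) -> None:
        self._store[step] = dict(y)

    def latest_step(self) -> Optional[int]:
        return max(self._store) if self._store else None

    def restore(self, step: int) -> dict:
        return dict(self._store[step])


def toy_resumable_fit(
    step_fn: Callable[[dict], dict],
    y0: dict,
    max_steps: int,
    mngr: ToyCheckpointManager,
) -> tuple[dict, int]:
    """Run up to ``max_steps`` total steps, resuming from the latest checkpoint if one exists."""
    latest = mngr.latest_step()
    if latest is None:
        y, start = dict(y0), 0  # no checkpoint yet: start from the initial parameters
    else:
        y, start = mngr.restore(latest), latest  # continue where the previous run stopped
    steps_run = 0
    for step in range(start, max_steps):
        y = step_fn(y)
        mngr.save(step + 1, y)  # checkpoint after every completed step
        steps_run += 1
    return y, steps_run


def gd_step(y: dict) -> dict:
    # One gradient-descent step on (mu - 2)^2 with learning rate 0.1.
    return {"mu": y["mu"] - 0.1 * 2.0 * (y["mu"] - 2.0)}


mngr = ToyCheckpointManager()
y10, ran_first = toy_resumable_fit(gd_step, {"mu": 0.0}, max_steps=10, mngr=mngr)
y20, ran_second = toy_resumable_fit(gd_step, {"mu": 0.0}, max_steps=20, mngr=mngr)
# ran_first == 10 and ran_second == 10: the second call only runs the remaining 10 steps.
```

This is why raising max_steps from 10 to 20 while reusing the same manager yields ten additional steps rather than a fresh 20-step run: max_steps counts total steps, including those already checkpointed.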