Features

Interactive fitting

ifit works like fit but adds a live progress bar showing the current negative log-likelihood (NLL) at each step. It also accepts a list of callbacks and a checkpoint manager for saving and resuming fits.

import everwillow as ew

result = ew.ifit(nll_fn, params, observation, max_steps=200)

Because ifit drives the optimization from a Python loop with side effects (the progress bar, callbacks, and checkpoint writes), it is not JIT-compatible. Use fit when you need JIT compilation.
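The trade-off can be illustrated with a toy loop. The sketch below is a plain-Python stand-in, not everwillow's implementation: toy_interactive_minimize, quadratic, and report are hypothetical names introduced only for illustration. It shows the kind of per-step side effect (here, appending to a list) that a Python-driven loop allows. Under jax.jit, Python side effects run only while the function is traced, not on every call, which is why a loop of this kind cannot simply be compiled.

```python
def toy_interactive_minimize(loss_and_grad, x0, max_steps=50, lr=0.1, report=None):
    """Hypothetical plain-Python loop; a stand-in, not everwillow's implementation."""
    x = x0
    for step in range(max_steps):
        loss, grad = loss_and_grad(x)
        if report is not None:
            report(step, loss)  # Python side effect, runs on every iteration
        x = x - lr * grad  # plain gradient-descent update
    return x


def quadratic(x):
    # Toy objective (x - 3)^2 and its derivative.
    return (x - 3.0) ** 2, 2.0 * (x - 3.0)


log = []
final_x = toy_interactive_minimize(
    quadratic, x0=0.0, report=lambda step, loss: log.append((step, loss))
)
```

After the loop, log holds one (step, loss) pair per iteration, which is the same information a progress bar or callback would consume.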

Callbacks

Pass a list of callbacks to ifit to run custom logic at each iteration. A callback is any callable with signature (iteration: int, y: State, state: SolverState) -> None. The current NLL is available via state.f_info.f.

The built-in HistoryCallback records step indices and NLL values for plotting convergence:

from everwillow.callback import HistoryCallback

history = HistoryCallback()
result = ew.ifit(nll_fn, params, observation, callbacks=[history])

# Plot convergence: NLL recorded at each step
import matplotlib.pyplot as plt

plt.plot(history.steps, history.nlls)
plt.xlabel("Step")
plt.ylabel("NLL")
plt.show()  # needed outside notebooks to display the figure
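Beyond plotting, the recorded values can be inspected numerically. The following is a minimal sketch that assumes history.steps and history.nlls behave like equal-length sequences of numbers; first_plateau_step is a hypothetical helper, and the lists below are made-up stand-in data rather than output from a real fit.

```python
def first_plateau_step(steps, nlls, tol=1e-3):
    """Return the first step whose NLL changed by less than tol, or None."""
    for i in range(1, len(nlls)):
        if abs(nlls[i] - nlls[i - 1]) < tol:
            return steps[i]
    return None


# Stand-in data; with a real fit these would be history.steps and history.nlls.
steps = [0, 1, 2, 3, 4]
nlls = [10.0, 6.0, 5.2, 5.1995, 5.1994]
plateau = first_plateau_step(steps, nlls)
```

With this data the change first drops below the tolerance between steps 2 and 3, so plateau is 3.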

Any callable matching the signature works as a custom callback. For example, to print the NLL every 10 steps:

def print_every_10(iteration, y, state):
    if iteration % 10 == 0:
        print(f"Step {iteration}: NLL = {state.f_info.f:.4f}")


result = ew.ifit(nll_fn, params, observation, callbacks=[history, print_every_10])
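A callback that needs to keep state between iterations can be written as a class with a __call__ method, since any callable with the right signature is accepted. The sketch below is illustrative only: BestNLLTracker is a hypothetical name, and the fake state objects built with SimpleNamespace merely mimic the state.f_info.f attribute path described above so the example can run without a fit. It also assumes the NLL can be converted with float().

```python
from types import SimpleNamespace


class BestNLLTracker:
    """Hypothetical stateful callback: remembers the lowest NLL seen so far."""

    def __init__(self):
        self.best_nll = float("inf")
        self.best_step = None

    def __call__(self, iteration, y, state):
        nll = float(state.f_info.f)  # assumes a scalar convertible to float
        if nll < self.best_nll:
            self.best_nll = nll
            self.best_step = iteration


# Exercise the callback with stand-in states instead of a real fit.
tracker = BestNLLTracker()
for step, nll in enumerate([5.0, 3.2, 3.5, 2.9, 3.0]):
    fake_state = SimpleNamespace(f_info=SimpleNamespace(f=nll))
    tracker(step, None, fake_state)
```

In a real fit, an instance would be passed in the callbacks list alongside other callbacks, and tracker.best_step and tracker.best_nll read afterwards.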

Checkpointing

ifit integrates with orbax for checkpointing. Pass a CheckpointManager as checkpoint_manager and the fit state is saved at each step. Reusing the same manager automatically resumes from the last saved step, so max_steps counts total steps rather than additional ones:

import orbax.checkpoint as ocp

mngr = ocp.CheckpointManager("/tmp/my_fit")

# Run for 50 steps
result = ew.ifit(nll_fn, params, observation, max_steps=50, checkpoint_manager=mngr)
mngr.wait_until_finished()  # checkpointing is async

# Resume and run to 100 steps total
result = ew.ifit(nll_fn, params, observation, max_steps=100, checkpoint_manager=mngr)
mngr.wait_until_finished()
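The resume semantics (max_steps as a total, with the second call picking up where the first stopped) can be sketched without orbax. The code below is a toy model under stated assumptions: it writes one JSON file per step to a temporary directory, and latest_step and run_with_resume are hypothetical helpers, not part of everwillow or orbax.

```python
import json
import tempfile
from pathlib import Path


def latest_step(ckpt_dir):
    """Return the highest saved step number in ckpt_dir, or None if empty."""
    saved = [int(p.stem) for p in ckpt_dir.glob("*.json")]
    return max(saved) if saved else None


def run_with_resume(ckpt_dir, max_steps):
    """Toy loop: resume from the newest checkpoint, run until max_steps total."""
    last = latest_step(ckpt_dir)
    if last is None:
        step, x = 0, 0.0  # fresh start
    else:
        data = json.loads((ckpt_dir / f"{last}.json").read_text())
        step, x = data["step"], data["x"]
    steps_run = 0
    while step < max_steps:
        x = 0.8 * x + 0.6  # toy update rule standing in for a solver step
        step += 1
        steps_run += 1
        payload = json.dumps({"step": step, "x": x})
        (ckpt_dir / f"{step}.json").write_text(payload)  # save every step
    return x, steps_run


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    _, first_run = run_with_resume(ckpt_dir, max_steps=50)
    x_resumed, second_run = run_with_resume(ckpt_dir, max_steps=100)

with tempfile.TemporaryDirectory() as tmp:
    x_straight, _ = run_with_resume(Path(tmp), max_steps=100)
```

Each call performs only 50 new steps, and the resumed result matches an uninterrupted 100-step run, which is the behaviour the orbax-backed example above relies on.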

State visualization

Call show() on any State to get a pretty-printed view with rich array formatting. In Jupyter notebooks this produces an interactive, foldable display; in terminals it prints formatted text:

import jax.numpy as jnp
import everwillow.statelib as sl

state = sl.State.from_pytree({"mu": 1.0, "sigma": jnp.array([0.1, 0.2])})
state.show()
# State({
#   'mu': 1.0,
#   'sigma': <jax.Array([0.1, 0.2], dtype=float32)>,
# })