Tips

JIT compilation

All everwillow functions are compatible with jax.jit. Wrapping your workflow in JIT compiles it into optimized XLA code, which can significantly speed up repeated evaluations (e.g. scanning over POI values for limits):

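import jax

from everwillow.hypotest.upper_limit import upper_limit

# AsymptoticCalculator, nll, params, and predict are assumed to be
# imported or defined earlier in your script.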


@jax.jit
def compute_limit(observed_n):
    observed = {"n": observed_n}
    calc = AsymptoticCalculator(
        nll_fn=nll,
        params=params,
        observation=observed,
        poi_key="mu",
        predict_fn=predict,
    )

    def cls_objective(poi):
        return calc.cls(calc.test(poi))

    return upper_limit(cls_objective, bounds=(0.0, 5.0), level=0.05)


# First call traces and compiles; later calls with the same input shape and dtype reuse the compiled code
limit = compute_limit(12.0)

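JIT compilation works best when the function structure is fixed and only array values change. Avoid passing inputs whose shape, dtype, or container structure changes between calls: each new combination makes JAX retrace and recompile the function, which cancels the speedup.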
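As a minimal sketch of this caching behaviour, independent of everwillow, the toy function below counts how often JAX traces it. The names sum_of_squares and trace_count are hypothetical and exist only for this illustration; the counter is a Python side effect, so it runs during tracing and not on cached calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while JAX traces


@jax.jit
def sum_of_squares(x):
    global trace_count
    trace_count += 1  # Python side effect: skipped when cached code is reused
    return jnp.sum(x ** 2)


sum_of_squares(jnp.ones(3))   # first call: trace and compile
sum_of_squares(jnp.zeros(3))  # same shape and dtype: cached code is reused
traces_same_shape = trace_count

sum_of_squares(jnp.ones(4))   # new shape: JAX traces and compiles again
traces_new_shape = trace_count

print(traces_same_shape, traces_new_shape)
```

Changing only the values leaves the trace count unchanged, while changing the shape adds a second trace. In a scan over many inputs, keeping shapes fixed therefore keeps compilation a one-time cost.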

64-bit precision

JAX defaults to 32-bit floats. For statistical inference, 64-bit precision is recommended to avoid numerical issues in likelihood evaluations:

import jax

jax.config.update("jax_enable_x64", True)

Set this at the top of your script, before any JAX operations.
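The effect of the flag can be checked with a short sketch. The values here are illustrative and not taken from any everwillow workflow: a term of size 1e-10 added to 1.0 is lost in 32-bit arithmetic but kept in 64-bit arithmetic, which is the kind of rounding that can distort small differences between likelihood values.

```python
import jax

jax.config.update("jax_enable_x64", True)  # set before any JAX operations

import jax.numpy as jnp

x = jnp.array(1.0)
print(x.dtype)  # default float dtype once x64 is enabled

# In 32-bit floats the small term is rounded away entirely.
lost_in_f32 = bool(jnp.float32(1.0) + jnp.float32(1e-10) == jnp.float32(1.0))

# In 64-bit floats the small term survives the addition.
kept_in_f64 = bool(jnp.float64(1.0) + jnp.float64(1e-10) != jnp.float64(1.0))

print(lost_in_f32, kept_in_f64)
```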