Source code for everwillow._src.inference.hypotest.upper_limit

"""Upper limit finding via root search.

This module provides generic root-finding functions for computing upper limits.
The functions are criterion-agnostic: the user provides an objective function
that maps POI values to some quantity, and the root finder locates where
that quantity equals a target level.

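All implementations are pure JAX and fully JIT-compatible:
    - upper_limit: Uses optimistix.Bisection (by default) for deterministic objectives
    - upper_limit_toys: Uses jax.lax.while_loop for stochastic (toy-based) objectives
    - upper_limit_scan: Evaluates the objective on a grid and interpolates the crossing
    - expected_upper_limit: Runs upper_limit once per sigma band of a BandValues objective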

Example usage:
    >>> # CLs-based upper limit (asymptotic)
    >>> def cls_objective(poi):
    ...     return calc.cls(calc.test(poi))
    >>> limit = upper_limit(cls_objective, bounds=(0, 5), level=0.05)

    >>> # CLs-based upper limit (toy-based)
    >>> def cls_objective_toys(poi, key):
    ...     result = toy_calc.test(poi, key=key)
    ...     return toy_calc.cls(result)
    >>> limit = upper_limit_toys(cls_objective_toys, bounds=(0, 5), key=key)
"""

from __future__ import annotations

import typing as tp

import equinox as eqx
import jax
import jax.numpy as jnp
import optimistix as optx
from jaxtyping import Array, PRNGKeyArray

from everwillow._src.inference.hypotest.results import BandValues

# Public API of this module: three root-search strategies plus the per-band
# (expected-limit) wrapper built on top of ``upper_limit``.
__all__ = [
    "expected_upper_limit", "upper_limit", "upper_limit_scan", "upper_limit_toys",
]


def upper_limit(
    objective_fn: tp.Callable[[float], Array],
    bounds: tuple[float, float],
    level: float = 0.05,
    *,
    solver: optx.AbstractRootFinder | None = None,
    rtol: float = 1e-4,
    atol: float = 1e-6,
    max_steps: int = 100,
) -> Array:
    """Locate the POI value at which ``objective_fn`` crosses ``level``.

    The problem is handed to ``optx.root_find`` as ``objective_fn(poi) - level
    == 0``. The routine is criterion-agnostic: compose ``objective_fn`` to encode
    whichever exclusion criterion you need (CLs, p_alt, ...).

    Args:
        objective_fn: Maps a POI value to the quantity of interest. Must be
            JAX-traceable (no ``float()`` on traced values) and should be
            monotonic on ``bounds`` for reliable convergence.
        bounds: ``(lower, upper)`` search interval. Forwarded to the solver as
            the ``"lower"``/``"upper"`` options; their midpoint is the initial
            guess ``y0``.
        level: Value the objective should take at the returned POI
            (default 0.05).
        solver: Root finder to use. When ``None``, an ``optx.Bisection`` built
            from ``rtol``/``atol`` is used.
        rtol: Relative tolerance; only used to build the default solver.
        atol: Absolute tolerance; only used to build the default solver.
        max_steps: Maximum number of solver steps.

    Returns:
        ``solution.value`` from ``optx.root_find``: the POI where
        ``objective_fn(poi)`` is approximately ``level``.

    Note:
        The solve runs with ``throw=False``: failures (e.g. exhausting
        ``max_steps``) are not raised, and ``solution.result`` is not
        inspected before the value is returned. No check is made here that
        the returned POI lies away from the search bounds.

        The objective is traced/JIT-compiled by optimistix, so avoid
        ``float()`` or other Python operations that break JAX tracing.

    Examples:
        >>> # Find where CLs = 0.05 (95% CL upper limit)
        >>> def cls_objective(poi):
        ...     return calc.cls(calc.test(poi))
        >>> limit = upper_limit(cls_objective, bounds=(0, 5), level=0.05)

        >>> # With custom solver
        >>> solver = optx.Newton(rtol=1e-5, atol=1e-5)
        >>> limit = upper_limit(cls_objective, bounds=(0, 5), solver=solver)
    """
    lower = bounds[0]
    upper = bounds[1]

    def shifted_objective(poi, _args):
        # Root finders locate zeros, so shift the objective by the target level.
        return objective_fn(poi) - level

    root_solver = solver
    if root_solver is None:
        root_solver = optx.Bisection(rtol=rtol, atol=atol)  # type: ignore[call-arg]

    # Start from the centre of the search interval.
    initial_guess = jnp.array((lower + upper) / 2.0)

    solution: optx.Solution = optx.root_find(
        shifted_objective,
        root_solver,
        initial_guess,
        args=None,
        options={"lower": lower, "upper": upper},
        max_steps=max_steps,
        throw=False,
    )
    return solution.value
def upper_limit_toys(
    objective_fn: tp.Callable[[float, PRNGKeyArray], Array],
    bounds: tuple[float, float],
    key: PRNGKeyArray,
    level: float = 0.05,
    *,
    tol: float = 1e-2,
    max_iterations: int = 100,
) -> Array:
    """Find POI value where stochastic objective equals target level.

    Uses bisection search with a fresh PRNG key at each iteration, derived as
    ``jax.random.fold_in(key, iteration)``. Pure JAX implementation via
    ``jax.lax.while_loop``, fully JIT-compatible.

    The bisection assumes the objective *decreases* as POI increases: if the
    objective at the midpoint is above ``level`` the lower edge moves up to the
    midpoint, otherwise the upper edge moves down to it.

    Args:
        objective_fn: Function mapping (poi, key) to quantity of interest.
            Should be monotonic (decreasing) as POI increases.
            Must be JAX-traceable (no float() calls on traced values).
        bounds: (lower, upper) search range for POI value. The bounds are
            promoted to a floating dtype so the loop carry keeps a fixed type.
        key: JAX PRNG key for reproducibility.
        level: Target value for the objective function (default 0.05).
        tol: Stop when objective is within tol of level (default 1e-2). Also
            used as ``rtol`` when checking whether the result sits on a bound.
        max_iterations: Maximum bisection iterations (default 100).

    Returns:
        The midpoint at which the objective first came within ``tol`` of
        ``level``; if that never happens, the midpoint of the final bracket
        after ``max_iterations`` steps.

    Raises:
        A runtime error (via ``eqx.error_if``) when the result is close to
        either search bound, which suggests the bounds are too narrow.

    Note:
        The result has statistical uncertainty from Monte Carlo sampling.
        The tolerance should account for this.

    Examples:
        >>> # CLs-based upper limit with toys
        >>> def cls_objective(poi, key):
        ...     result = toy_calc(
        ...         nll_fn, params, "mu", poi,
        ...         sample_fn=sample_fn, nll_factory=nll_factory, key=key
        ...     )
        ...     return result.cl_s
        >>> limit = upper_limit_toys(cls_objective, bounds=(0, 5), key=key)
    """
    # lax.while_loop requires the carry type to stay identical across
    # iterations. Strongly-typed integer bounds (e.g. NumPy ints) would give an
    # int carry that the float midpoint cannot replace, so promote to float.
    dtype = jnp.result_type(bounds[0], bounds[1], float)
    lo0 = jnp.asarray(bounds[0], dtype=dtype)
    hi0 = jnp.asarray(bounds[1], dtype=dtype)

    def cond_fn(state):
        iteration, _lo, _hi, converged = state
        return (iteration < max_iterations) & ~converged

    def body_fn(state):
        iteration, lo, hi, _ = state
        mid = (lo + hi) / 2.0

        # Fresh key for this iteration
        key_iter = jax.random.fold_in(key, iteration)

        # Evaluate objective at midpoint
        obj_mid = objective_fn(mid, key_iter)

        # Check convergence (objective close to target level)
        converged = jnp.abs(obj_mid - level) < tol

        # Bisection: objective assumed to decrease as POI increases, so a value
        # above the level means the crossing lies to the right of mid.
        above = obj_mid > level

        # On convergence collapse the bracket onto mid, so the final
        # (lo + hi) / 2 is the point that actually met the tolerance rather
        # than the centre of the neighbouring half-interval (never evaluated).
        new_lo = jnp.where(converged | above, mid, lo)
        new_hi = jnp.where(converged | ~above, mid, hi)

        return (iteration + 1, new_lo, new_hi, converged)

    # Initial state: (iteration, lo, hi, converged)
    init_state = (0, lo0, hi0, jnp.array(False))
    _, lo, hi, _ = jax.lax.while_loop(cond_fn, body_fn, init_state)

    result = (lo + hi) / 2.0

    at_lower = jnp.isclose(result, lo0, rtol=tol)
    at_upper = jnp.isclose(result, hi0, rtol=tol)
    result = eqx.error_if(
        result,
        at_lower | at_upper,
        "upper_limit_toys: root not found within bounds. "
        "The limit is at the search boundary, suggesting the bounds are too narrow.",
    )
    return result  # noqa: RET504
def upper_limit_scan(
    objective_fn: tp.Callable[[float], Array],
    scan: Array,
    level: float = 0.05,
) -> Array:
    """Estimate the POI where the objective crosses ``level`` from a grid scan.

    The objective is evaluated at every point of ``scan`` with ``jax.vmap`` and
    the crossing is located by linear interpolation with ``jnp.interp``. Fully
    JIT-compatible.

    Handy when the objective is expensive and its evaluations should be reused,
    or when the objective curve is to be visualised.

    Args:
        objective_fn: Maps a POI value to the quantity of interest. Must be
            JAX-traceable.
        scan: POI grid to evaluate. Should be monotonically increasing and
            bracket the expected limit.
        level: Value the objective should take at the returned POI
            (default 0.05).

    Returns:
        The interpolated POI at which the objective equals ``level``.

    Raises:
        A runtime error (via ``eqx.error_if``) when the interpolated value is
        close to ``scan[0]`` or ``scan[-1]`` (``jnp.interp`` clamps to the end
        points when ``level`` is outside the sampled range).

    Note:
        The objective values are reversed before interpolating, so they are
        assumed to *decrease* along ``scan`` (as CLs typically does with POI).
        ``jnp.interp`` does not validate that ordering; an increasing or
        non-monotonic objective gives meaningless results. Accuracy depends on
        the density of grid points near the crossing.

    Examples:
        >>> # Scan CLs on a grid
        >>> scan = jnp.linspace(0, 2, 50)
        >>> limit = upper_limit_scan(
        ...     lambda poi: calc.cls(calc.test(poi)),
        ...     scan,
        ...     level=0.05,
        ... )

        >>> # With finer grid near expected limit
        >>> scan = jnp.concatenate([
        ...     jnp.linspace(0, 0.5, 10),
        ...     jnp.linspace(0.5, 1.5, 30),
        ...     jnp.linspace(1.5, 3, 10),
        ... ])
        >>> limit = upper_limit_scan(cls_objective, scan, level=0.05)
    """
    sampled = jax.vmap(objective_fn)(scan)

    # jnp.interp needs increasing sample points; a decreasing objective becomes
    # increasing once both the values and the grid are reversed.
    increasing_values = sampled[::-1]
    matching_pois = scan[::-1]
    crossing = jnp.interp(level, increasing_values, matching_pois)

    # A result pinned to either end of the grid means interp clamped it.
    on_edge = jnp.isclose(crossing, scan[0]) | jnp.isclose(crossing, scan[-1])
    return eqx.error_if(
        crossing,
        on_edge,
        "upper_limit_scan: root not found within scan range. "
        "The limit is at the scan boundary, suggesting the scan range is too narrow.",
    )
def expected_upper_limit(
    band_objective_fn: tp.Callable[[float], BandValues],
    bounds: tuple[float, float],
    level: float = 0.05,
    **solver_kwargs: tp.Any,
) -> BandValues:
    """Find expected upper limits at each sigma band.

    Calls :func:`upper_limit` once per band name in ``BandValues._NAMES``,
    each time extracting the corresponding scalar from the ``BandValues``
    returned by ``band_objective_fn``.

    Args:
        band_objective_fn: Function mapping POI value to ``BandValues`` of
            the objective quantity (e.g. expected CLs at each sigma band).
        bounds: (lower, upper) search range for POI value.
        level: Target level (default 0.05 for 95% CL).
        **solver_kwargs: Additional arguments passed to :func:`upper_limit`
            (e.g., ``rtol``, ``atol``, ``max_steps``).

    Returns:
        BandValues where each entry is the upper limit at that sigma band.

    Raises:
        ValueError: If ``band_objective_fn`` returns ``None``, or returns an
            object that cannot be indexed by a band name.

    Example:
        >>> def band_cls_objective(poi):
        ...     result = calc.test(poi)
        ...     bands = calc.expected(result)
        ...     return bands.cl_s
        >>> brazil = expected_upper_limit(band_cls_objective, bounds=(0, 5))
        >>> for name, val in brazil:
        ...     print(f"  {name}: {float(val):.4f}")
    """
    band_limits = {}
    for band_name in BandValues._NAMES:
        # ``_name`` is bound as a default argument so each closure keeps its
        # own band name instead of late-binding to the final loop value.
        def _band_fn(poi: float, _name: str = band_name) -> Array:
            bands = band_objective_fn(poi)
            if bands is None:
                msg = (  # type: ignore[unreachable]
                    f"expected_upper_limit: 'band_objective_fn' returned None "
                    f"for POI value {poi!r}; expected a BandValues instance."
                )
                raise ValueError(msg)
            try:
                return bands[_name]
            except (KeyError, TypeError, AttributeError) as exc:
                # Adjacent literals (not backslash continuation inside the
                # string) so source indentation never leaks into the message.
                msg = (
                    "expected_upper_limit: 'band_objective_fn' returned an "
                    "object that does not provide the requested band "
                    f"{_name!r}. Ensure it returns a valid BandValues with "
                    "all expected bands."
                )
                raise ValueError(msg) from exc

        band_limits[band_name] = upper_limit(
            _band_fn,
            bounds,
            level,
            **solver_kwargs,
        )
    return BandValues(**band_limits)