Source code for everwillow._src.inference.hypotest.calculators

"""Hypothesis test calculators.

This module provides calculators that orchestrate hypothesis testing by
computing the test statistic, then delegating p-value computation to
Distribution objects.

- ``HypoTestCalculator``: Generic base — forwards all kwargs to the
  test statistic.
- ``AsymptoticCalculator``: Extends the base with Asimov dataset config
  (``predict_fn``/``mu_asimov`` or ``asimov_observation``) for Cowan et al.
  asymptotic workflows.
"""

from __future__ import annotations

import typing as tp

import equinox as eqx
from jaxtyping import Array, PyTree

import everwillow._src.statelib as sl
from everwillow._src.inference.hypotest.distributions import (
    Distribution,
    QTildeAsymptotic,
)
from everwillow._src.inference.hypotest.results import ExpectedBands, HypoTestResult
from everwillow._src.inference.hypotest.test_statistics import QTilde, TestStatistic
from everwillow._src.inference.hypotest.utils import cl_s

__all__ = ["AsymptoticCalculator", "HypoTestCalculator"]


class HypoTestCalculator(eqx.Module):
    """Generic hypothesis test calculator.

    A test is carried out in two stages:

    1. The configured test statistic is evaluated on the observed data.
    2. The configured ``Distribution`` converts that result into p-values.

    Everything model-specific is fixed when the calculator is built, so
    ``test(poi_test)`` needs only the POI value being probed. Extra keyword
    arguments given to ``test()`` are passed through to the test statistic
    untouched.

    Attributes:
        nll_fn: Negative log-likelihood function taking (params, observation).
        params: Initial parameter state.
        observation: Observed data passed to nll_fn.
        poi_key: Canonical key for the parameter of interest, e.g. "mu".
        test_statistic: Test statistic to use. Defaults to QTilde.
        distribution: Distribution for p-value computation. Defaults to
            QTildeAsymptotic.
    """

    nll_fn: tp.Callable[[PyTree, PyTree], float]
    params: sl.State
    observation: PyTree
    poi_key: sl.K
    test_statistic: TestStatistic = eqx.field(default_factory=QTilde)
    distribution: Distribution = eqx.field(default_factory=QTildeAsymptotic)

    def test(self, poi_test: float, **kwargs: tp.Any) -> HypoTestResult:
        """Run the hypothesis test at a single POI value.

        Args:
            poi_test: Test value for the POI.
            **kwargs: Passed verbatim to ``test_statistic.compute``. This
                covers test-statistic-specific arguments (e.g. ``predict_fn``,
                ``mu_asimov`` for Cowan test statistics) as well as fit
                options.

        Returns:
            HypoTestResult holding the observed test statistic value, the
            null and alternative p-values, and the full test statistic result.
        """
        stat = self.test_statistic.compute(
            self.nll_fn,
            self.params,
            self.observation,
            self.poi_key,
            poi_test,
            **kwargs,
        )
        # Null p-value is requested before the alternative one, as before.
        p_null = self.distribution.null_pval(stat)
        p_alt = self.distribution.alt_pval(stat)
        return HypoTestResult(
            q_obs=stat.value,
            pnull=p_null,
            palt=p_alt,
            test_stat_result=stat,
        )

    def cls(self, result: HypoTestResult) -> Array | None:
        """Compute CLs = pnull / palt from a hypothesis test result.

        The ratio itself is evaluated by the ``cl_s`` helper.

        Args:
            result: HypoTestResult from ``test()``.

        Returns:
            CLs value, or None if either p-value is None.
        """
        if result.pnull is not None and result.palt is not None:
            return cl_s(result.pnull, result.palt)
        return None

    def expected(self, result: HypoTestResult) -> ExpectedBands | None:
        """Compute expected p-values at standard sigma bands.

        This is a thin pass-through: the stored test statistic result is
        handed to the distribution's ``expected_pvalues`` method and whatever
        it returns is returned unchanged.

        Args:
            result: HypoTestResult from ``test()``.

        Returns:
            ExpectedBands with p-values at each sigma level.

        Raises:
            NotImplementedError: If the distribution does not support
                expected p-value computation.
        """
        compute_bands = self.distribution.expected_pvalues
        return compute_bands(result.test_stat_result)
class AsymptoticCalculator(HypoTestCalculator):
    """Calculator for Cowan et al. asymptotic hypothesis tests.

    Adds Asimov dataset configuration on top of ``HypoTestCalculator``.
    ``test()`` injects these fields into the test statistic call
    automatically.

    There are two ways to supply the Asimov dataset:

    1. **Pre-computed**: pass ``asimov_observation`` directly. This is useful
       when the Asimov dataset is expensive to generate or when the model
       prediction function is not available (e.g. combined models with
       multiple observation channels).
    2. **On-the-fly**: pass ``predict_fn`` and ``mu_asimov``. The Asimov
       dataset is generated at each ``test()`` call by setting the POI to
       ``mu_asimov`` and calling ``predict_fn``.

    When both are provided, ``asimov_observation`` takes precedence and
    ``predict_fn`` / ``mu_asimov`` are ignored.

    Example:
        >>> calc = AsymptoticCalculator(
        ...     nll_fn=nll_fn, params=params, observation=observed,
        ...     poi_key="mu", predict_fn=my_predict_fn,
        ... )
        >>> result = calc.test(poi_test=1.0)

    Attributes:
        predict_fn: Function mapping parameter state to expected observation.
            Used to create the Asimov dataset at ``mu_asimov``.
        mu_asimov: POI value for Asimov dataset generation. Defaults to 0.0
            (background-only, for exclusion tests). Use 1.0 for discovery
            tests.
        asimov_observation: Pre-computed Asimov dataset. When provided, this
            is used directly instead of generating one via ``predict_fn`` /
            ``mu_asimov``.
    """

    predict_fn: tp.Callable[[sl.State], PyTree] | None = None
    mu_asimov: float = 0.0
    asimov_observation: PyTree | None = None

    def test(self, poi_test: float, **kwargs: tp.Any) -> HypoTestResult:
        """Run asymptotic hypothesis test.

        ``predict_fn``, ``mu_asimov`` and ``asimov_observation`` are taken
        from the instance fields for every key the caller did not pass
        explicitly; the merged keyword arguments then go to the parent
        ``test()``.

        Args:
            poi_test: Test value for the POI.
            **kwargs: Additional arguments forwarded to the test statistic
                (e.g. fit options). Can override ``predict_fn``,
                ``mu_asimov``, or ``asimov_observation`` for one-off use.

        Returns:
            HypoTestResult with observed p-values.

        Note:
            All three keys are always forwarded. A one-off ``predict_fn``
            override is therefore still accompanied by the stored
            ``asimov_observation`` unless that key is overridden as well.
        """
        forwarded = dict(kwargs)
        # Caller-supplied values win; instance fields only fill in the gaps.
        for field_name in ("predict_fn", "mu_asimov", "asimov_observation"):
            forwarded.setdefault(field_name, getattr(self, field_name))
        return super().test(poi_test, **forwarded)