"""Distributions for converting test statistics to p-values.
Provides asymptotic distribution classes (Cowan et al., arXiv:1007.1727)
and empirical distributions from toy Monte Carlo. Every class exposes
``null_pval`` and ``alt_pval``; the asymptotic classes also expose ``cdf``.
"""
from __future__ import annotations
import abc
import typing as tp
import warnings
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array
from everwillow._src.inference.hypotest.results import (
BandValues,
ExpectedBands,
TestStatResult,
ToyResult,
)
from everwillow._src.inference.hypotest.utils import (
cl_s,
sigma_from_asimov,
significance,
)
# Public API of this module: the abstract base class, the asymptotic
# (Cowan et al.) distributions, and the toy-based empirical distributions.
__all__ = [
    "Distribution",
    "EmpiricalDistribution",
    "Q0Asymptotic",
    "QMuAsymptotic",
    "QTildeAsymptotic",
    "SimpleEmpiricalDistribution",
    "TMuAsymptotic",
    "TMuTildeAsymptotic",
]
# Standard-normal CDF (Phi) and its inverse (Phi^-1). They are used throughout
# to map sqrt(test statistic) to tail probabilities and p-values back to
# Gaussian significances.
_PHI, _PPF = jax.scipy.stats.norm.cdf, jax.scipy.stats.norm.ppf
# Gaussian fluctuations, in units of sigma, at which expected bands are
# evaluated, ordered from -2 sigma to +2 sigma.
_BAND_SIGMAS = tuple(float(k) for k in range(-2, 3))
def _build_expected_bands(
    dist: Distribution,
    result: TestStatResult,
    expected_q_fn: tp.Callable[[float], Array],
) -> ExpectedBands:
    """Build ExpectedBands by evaluating p-values at each sigma fluctuation.

    For every fluctuation ``N`` in ``_BAND_SIGMAS`` (-2 ... +2), a synthetic
    TestStatResult whose value is ``expected_q_fn(N)`` is passed through
    ``dist.null_pval`` and ``dist.alt_pval``. All derived quantities (CLs and
    significances) are computed eagerly, so the returned ExpectedBands holds
    fully populated BandValues.

    Args:
        dist: Distribution whose null_pval/alt_pval will be called.
        result: Original result; its ``test`` and ``q_asimov`` are copied into
            every synthetic result.
        expected_q_fn: Maps a sigma fluctuation ``N``, i.e. an element of
            ``_BAND_SIGMAS`` such as ``-2.0`` (not a positional index), to the
            expected test statistic value at that fluctuation.

    Returns:
        ExpectedBands with BandValues for null_pvalue, alt_pvalue, cl_s,
        null_sig, and alt_sig.
    """
    pnulls = []
    palts = []
    for n in _BAND_SIGMAS:
        synthetic = TestStatResult(value=expected_q_fn(n), test=result.test, q_asimov=result.q_asimov)
        pnulls.append(dist.null_pval(synthetic))
        palts.append(dist.alt_pval(synthetic))
    # BandValues is filled positionally here. This assumes its field order
    # matches _BAND_SIGMAS (-2 sigma ... +2 sigma); confirm in results.py.
    null_pvalue = BandValues(*pnulls)
    alt_pvalue = BandValues(*palts)
    # Iterating a BandValues yields (band name, value) pairs. Both objects
    # describe the same bands, so a length mismatch is a bug: strict=True
    # raises instead of silently dropping bands from the CLs result.
    cls_values = BandValues(
        **{name: cl_s(pn, pa) for (name, pn), (_, pa) in zip(null_pvalue, alt_pvalue, strict=True)}
    )
    null_sig = BandValues(**{name: significance(pn) for name, pn in null_pvalue})
    alt_sig = BandValues(**{name: significance(pa) for name, pa in alt_pvalue})
    return ExpectedBands(
        null_pvalue=null_pvalue,
        alt_pvalue=alt_pvalue,
        cl_s=cls_values,
        null_sig=null_sig,
        alt_sig=alt_sig,
    )
def _require_q_asimov(result: TestStatResult, cls_name: str, pval_type: str) -> bool:
    """Report whether ``result`` carries an Asimov test statistic.

    When ``result.q_asimov`` is ``None``, a warning is emitted with
    ``stacklevel=3``. That attributes it to the code that called the public
    method which in turn called this helper.

    Args:
        result: Test statistic result to inspect.
        cls_name: Name of the calling distribution class, used in the message.
        pval_type: Label for the requested computation (e.g. "Null").

    Returns:
        True if q_asimov is present, False otherwise.
    """
    has_asimov = result.q_asimov is not None
    if not has_asimov:
        message = (
            f"{pval_type} p-value computation in {cls_name} cannot be performed "
            "without an Asimov test statistic."
        )
        warnings.warn(message, stacklevel=3)
    return has_asimov
# =============================================================================
# Base Distribution
# =============================================================================
class Distribution(eqx.Module):
    r"""Abstract base for test statistic distributions.

    Concrete subclasses supply two p-values for a :class:`TestStatResult`:

    - ``null_pval``: p-value under the null hypothesis
      (:math:`\mu' = \mu`, where :math:`\mu` is the hypothesis being tested).
    - ``alt_pval``: p-value under an alternative hypothesis
      (:math:`\mu' = 0` for exclusion, :math:`\mu' = 1` for discovery).

    Gaussian significances are derived from those p-values here. The
    ``expected_pvalues`` method raises ``NotImplementedError`` unless a
    subclass overrides it.
    """

    @abc.abstractmethod
    def null_pval(self, result: TestStatResult) -> Array | None:
        r"""Return the p-value under the null hypothesis (:math:`\mu' = \mu`).

        Args:
            result: Test statistic result.

        Returns:
            Null p-value, or None when required data (e.g. q_asimov) is
            missing.
        """
        ...

    @abc.abstractmethod
    def alt_pval(self, result: TestStatResult) -> Array | None:
        """Return the p-value under an alternative hypothesis.

        Args:
            result: Test statistic result.

        Returns:
            Alternative p-value, or None when required data (e.g. q_asimov)
            is missing.
        """
        ...

    def null_significance(self, result: TestStatResult) -> Array | None:
        r"""Convert the null p-value to :math:`Z = \Phi^{-1}(1 - p_\text{null})`.

        Args:
            result: Test statistic result.

        Returns:
            Significance Z, or None when ``null_pval`` returns None.
        """
        # Phi^-1(1 - p) == -Phi^-1(p) by symmetry of the standard normal.
        p_value = self.null_pval(result)
        return None if p_value is None else -_PPF(p_value)

    def alt_significance(self, result: TestStatResult) -> Array | None:
        r"""Convert the alternative p-value to :math:`Z = \Phi^{-1}(1 - p_\text{alt})`.

        Args:
            result: Test statistic result.

        Returns:
            Significance Z, or None when ``alt_pval`` returns None.
        """
        p_value = self.alt_pval(result)
        return None if p_value is None else -_PPF(p_value)

    def expected_pvalues(self, result: TestStatResult) -> ExpectedBands | None:
        """Compute expected p-values at the standard sigma bands.

        Args:
            result: Test statistic result.

        Returns:
            ExpectedBands with (pnull, palt) at each sigma level.

        Raises:
            NotImplementedError: Always, for distributions that do not
                override this method.
        """
        raise NotImplementedError
# =============================================================================
# Asymptotic Distributions (Cowan et al. formulas)
# =============================================================================
class TMuAsymptotic(Distribution):
    r"""Asymptotic distribution of the two-sided statistic :math:`t_\mu` (Eq. 38).

    Pair with the :math:`t_\mu` test statistic when building two-sided
    confidence intervals. This class does not override ``expected_pvalues``.
    """

    def cdf(self, q: Array, mu: Array, mu_prime: Array, sigma: Array) -> Array:
        r"""Evaluate :math:`F(t_\mu \mid \mu') = \Phi(\sqrt{t} + \delta) + \Phi(\sqrt{t} - \delta) - 1`.

        Here :math:`\delta = (\mu - \mu')/\sigma`. Negative ``q`` is clipped
        to 0 before the square root.
        """
        shift = (mu - mu_prime) / sigma
        root_q = jnp.sqrt(jnp.maximum(q, 0.0))
        return _PHI(root_q + shift) + _PHI(root_q - shift) - 1.0

    def null_pval(self, result: TestStatResult) -> Array:
        r"""Return :math:`p = 2(1 - \Phi(\sqrt{t_\mu}))`; no :math:`\sigma` is required."""
        root_t = jnp.sqrt(jnp.maximum(result.value, 0.0))
        upper_tail = 1.0 - _PHI(root_t)
        return 2.0 * upper_tail

    def alt_pval(self, result: TestStatResult) -> Array | None:
        r"""Return :math:`p = 2 - \Phi(\sqrt{t} + \sqrt{q_A}) - \Phi(\sqrt{t} - \sqrt{q_A})`.

        With :math:`q_A = \mu^2/\sigma^2` (Asimov under :math:`\mu' = 0`),
        :math:`\sqrt{q_A} = \mu/\sigma = (\mu - \mu')/\sigma`. Returns None
        (after warning) when ``result.q_asimov`` is missing.
        """
        if not _require_q_asimov(result, type(self).__name__, "Alternative"):
            return None
        root_t = jnp.sqrt(jnp.maximum(result.value, 0.0))
        root_qa = jnp.sqrt(jnp.maximum(result.q_asimov, 0.0))
        return 2.0 - _PHI(root_t + root_qa) - _PHI(root_t - root_qa)
class TMuTildeAsymptotic(Distribution):
    r"""Asymptotic distribution of :math:`\tilde{t}_\mu` (two-sided, physical bound; Eq. 40/44).

    Pair with the :math:`\tilde{t}_\mu` test statistic for two-sided tests
    under the physical constraint :math:`\mu \geq 0`. The CDF (Eq. 44) is
    piecewise at :math:`\mu^2/\sigma^2`, and both pieces have the
    :math:`\Phi + \Phi - 1` form.
    """

    def cdf(self, q: Array, mu: Array, mu_prime: Array, sigma: Array) -> Array:
        r"""Evaluate :math:`F(\tilde{t}_\mu \mid \mu')`, piecewise at :math:`\mu^2/\sigma^2`."""
        shift = (mu - mu_prime) / sigma
        kink = (mu / sigma) ** 2
        root_q = jnp.sqrt(jnp.maximum(q, 0.0))
        # t~ <= mu^2/sigma^2: Phi(sqrt(t~) + shift) + Phi(sqrt(t~) - shift) - 1
        below = _PHI(root_q + shift) + _PHI(root_q - shift) - 1.0
        # t~ > mu^2/sigma^2: the second Phi argument becomes linear in t~.
        above = _PHI(root_q + shift) + _PHI((q + kink) / (2.0 * mu / sigma) - shift) - 1.0
        return jnp.where(q <= kink, below, above)

    def null_pval(self, result: TestStatResult) -> Array | None:
        r"""Return the null p-value (:math:`\mu' = \mu`) with :math:`q_A = \mu^2/\sigma^2`.

        .. math::

            p_{\mu'=\mu} = \begin{cases}
                2\bigl(1 - \Phi(\sqrt{\tilde{t}})\bigr)
                    & \text{if } \tilde{t} \leq q_A \\
                2 - \Phi(\sqrt{\tilde{t}})
                    - \Phi\!\left(\frac{\tilde{t} + q_A}{2\sqrt{q_A}}\right)
                    & \text{if } \tilde{t} > q_A
            \end{cases}

        Returns None (after warning) when ``result.q_asimov`` is missing.
        """
        if not _require_q_asimov(result, type(self).__name__, "Null"):
            return None
        t_obs, qa = result.value, result.q_asimov
        root_t = jnp.sqrt(jnp.maximum(t_obs, 0.0))
        root_qa = jnp.sqrt(jnp.maximum(qa, 0.0))
        below = 2.0 * (1.0 - _PHI(root_t))
        above = 2.0 - _PHI(root_t) - _PHI((t_obs + qa) / (2.0 * root_qa))
        return jnp.where(t_obs <= qa, below, above)

    def alt_pval(self, result: TestStatResult) -> Array | None:
        r"""Return the alternative p-value (:math:`\mu' = 0`) with :math:`q_A = \mu^2/\sigma^2`.

        .. math::

            p_{\mu'=0} = \begin{cases}
                2 - \Phi(\sqrt{\tilde{t}} + \sqrt{q_A})
                    - \Phi(\sqrt{\tilde{t}} - \sqrt{q_A})
                    & \text{if } \tilde{t} \leq q_A \\
                2 - \Phi(\sqrt{\tilde{t}} + \sqrt{q_A})
                    - \Phi\!\left(\frac{\tilde{t} - q_A}{2\sqrt{q_A}}\right)
                    & \text{if } \tilde{t} > q_A
            \end{cases}

        Returns None (after warning) when ``result.q_asimov`` is missing.
        """
        if not _require_q_asimov(result, type(self).__name__, "Alternative"):
            return None
        t_obs, qa = result.value, result.q_asimov
        root_t = jnp.sqrt(jnp.maximum(t_obs, 0.0))
        root_qa = jnp.sqrt(jnp.maximum(qa, 0.0))
        below = 2.0 - _PHI(root_t + root_qa) - _PHI(root_t - root_qa)
        above = 2.0 - _PHI(root_t + root_qa) - _PHI((t_obs - qa) / (2.0 * root_qa))
        return jnp.where(t_obs <= qa, below, above)
class Q0Asymptotic(Distribution):
    r"""Asymptotic distribution of the discovery statistic :math:`q_0` (Eq. 49).

    Pair with the :math:`q_0` test statistic to quote discovery significance.
    """

    def cdf(self, q: Array, mu: Array, mu_prime: Array, sigma: Array) -> Array:
        r"""Evaluate :math:`F(q_0 \mid \mu') = \Phi(\sqrt{q_0} - \mu'/\sigma)`.

        ``mu`` is accepted for signature parity with the other classes but
        does not enter the formula.
        """
        root_q = jnp.sqrt(jnp.maximum(q, 0.0))
        return _PHI(root_q - mu_prime / sigma)

    def null_pval(self, result: TestStatResult) -> Array:
        r"""Return :math:`p = 1 - \Phi(\sqrt{q_0})`; no :math:`\sigma` is required."""
        root_q = jnp.sqrt(jnp.maximum(result.value, 0.0))
        return 1.0 - _PHI(root_q)

    def alt_pval(self, result: TestStatResult) -> Array | None:
        r"""Return :math:`p = 1 - \Phi(\sqrt{q_0} - \sqrt{q_A})`.

        With :math:`q_A = \mu_\text{asimov}^2/\sigma^2` (Asimov under signal),
        :math:`\sqrt{q_A} = \mu_\text{asimov}/\sigma`. Returns None (after
        warning) when ``result.q_asimov`` is missing.
        """
        if not _require_q_asimov(result, type(self).__name__, "Alternative"):
            return None
        root_q = jnp.sqrt(jnp.maximum(result.value, 0.0))
        root_qa = jnp.sqrt(jnp.maximum(result.q_asimov, 0.0))
        return 1.0 - _PHI(root_q - root_qa)

    def expected_pvalues(self, result: TestStatResult) -> ExpectedBands | None:
        r"""Expected p-values at :math:`\pm N\sigma` fluctuations under the signal hypothesis.

        With :math:`\sqrt{q_A} = \mu_\text{asimov}/\sigma`, band :math:`N`
        uses :math:`q = \max(0, \sqrt{q_A} + N)^2`. Upward fluctuations
        (:math:`+N`) therefore raise the discovery significance, which is the
        opposite orientation to the exclusion tests.

        Args:
            result: Must contain ``q_asimov`` for :math:`\sqrt{q_A}`.

        Returns:
            ExpectedBands with (pnull, palt) at each sigma level, or None if
            q_asimov is missing.
        """
        if not _require_q_asimov(result, type(self).__name__, "Expected"):
            return None
        root_qa = jnp.sqrt(jnp.maximum(result.q_asimov, 0.0))
        return _build_expected_bands(
            self,
            result,
            lambda n: jnp.maximum(root_qa + n, 0.0) ** 2,
        )
class QMuAsymptotic(Distribution):
    r"""Asymptotic distribution for :math:`q_\mu` (upper limit, Eq. 57).

    Used with the :math:`q_\mu` test statistic for upper limit calculations.
    """

    def cdf(self, q: Array, mu: Array, mu_prime: Array, sigma: Array) -> Array:
        r"""CDF: :math:`F(q_\mu \mid \mu') = \Phi(\sqrt{q_\mu} - (\mu - \mu')/\sigma)`.

        Negative ``q`` is clipped to 0 before taking the square root.
        """
        sqrt_q = jnp.sqrt(jnp.maximum(q, 0.0))
        return _PHI(sqrt_q - (mu - mu_prime) / sigma)

    def null_pval(self, result: TestStatResult) -> Array:
        r"""Null p-value: :math:`p = 1 - \Phi(\sqrt{q_\mu})`. No :math:`\sigma` needed."""
        sqrt_q = jnp.sqrt(jnp.maximum(result.value, 0.0))
        return 1.0 - _PHI(sqrt_q)

    def alt_pval(self, result: TestStatResult) -> Array | None:
        r"""Alt p-value: :math:`p = 1 - \Phi(\sqrt{q_\mu} - \sqrt{q_A})`.

        :math:`q_A = \mu^2/\sigma^2` (Asimov under :math:`\mu'=0`),
        so :math:`\sqrt{q_A} = \mu/\sigma`. Returns None (after warning) when
        ``result.q_asimov`` is missing.
        """
        if not _require_q_asimov(result, self.__class__.__name__, "Alternative"):
            return None
        sqrt_q = jnp.sqrt(jnp.maximum(result.value, 0.0))
        sqrt_qa = jnp.sqrt(jnp.maximum(result.q_asimov, 0.0))
        return 1.0 - _PHI(sqrt_q - sqrt_qa)

    def expected_pvalues(self, result: TestStatResult) -> ExpectedBands | None:
        r"""Expected p-values at :math:`\pm N\sigma` fluctuations under background-only.

        :math:`q_A = \mu^2/\sigma^2` (Asimov under :math:`\mu'=0`),
        so :math:`\sqrt{q_A} = \mu/\sigma`.
        At band :math:`N`, the expected :math:`\hat{\mu} = N\sigma`, giving
        :math:`\sqrt{q} = \max(0, \mu/\sigma - N)`.
        Synthetic TestStatResult objects are passed through the existing
        null_pval/alt_pval methods to reuse the CDF logic.

        Args:
            result: Must contain ``q_asimov`` for :math:`\sigma` extraction.

        Returns:
            ExpectedBands with (pnull, palt) at each sigma level,
            or None if q_asimov is missing.
        """
        if not _require_q_asimov(result, self.__class__.__name__, "Expected"):
            return None
        sigma = sigma_from_asimov(result.test, result.q_asimov)
        # Guard: at poi=0, sigma=0 and mu/sigma = 0/0 = NaN; substitute 0.
        # The denominator is sanitised as well (the "double where" pattern).
        # A bare jnp.where(sigma > 0, mu / sigma, 0.0) still evaluates 0/0 in
        # the discarded branch, which makes gradients NaN and, when run
        # eagerly, trips jax_debug_nans, even though the forward value is masked.
        # With mu/sigma = 0 the expected q is max(-N, 0)**2. That is nonzero for
        # N < 0, but if q_asimov is also 0 there, null and alt p-values coincide
        # and CLs = 1.
        has_sigma = sigma > 0
        safe_sigma = jnp.where(has_sigma, sigma, 1.0)
        mu_over_sigma = jnp.where(has_sigma, result.test / safe_sigma, 0.0)

        def expected_q_fn(n: float) -> Array:
            return jnp.maximum(mu_over_sigma - n, 0.0) ** 2

        return _build_expected_bands(self, result, expected_q_fn)
class QTildeAsymptotic(Distribution):
    r"""Asymptotic distribution for :math:`\tilde{q}_\mu` (upper limit with physical bound, Eq. 64).

    Used with the :math:`\tilde{q}_\mu` test statistic for hypothesis testing
    with the physical constraint :math:`\mu \geq 0`. The CDF is piecewise at
    :math:`\tilde{q} = \mu^2/\sigma^2 = q_\text{asimov}`.
    """

    def cdf(self, q: Array, mu: Array, mu_prime: Array, sigma: Array) -> Array:
        r"""CDF: :math:`F(\tilde{q}_\mu \mid \mu')`, piecewise at threshold :math:`\mu^2/\sigma^2`."""
        sqrt_q = jnp.sqrt(jnp.maximum(q, 0.0))
        threshold = (mu / sigma) ** 2
        # Standard region: Phi(sqrt(q~) - (mu - mu')/sigma)
        f_standard = _PHI(sqrt_q - (mu - mu_prime) / sigma)
        # Boundary region: Phi((q~ - (mu^2 - 2 mu mu')/sigma^2) / (2 mu/sigma))
        f_boundary = _PHI((q - (mu**2 - 2 * mu * mu_prime) / sigma**2) / (2.0 * mu / sigma))
        return jnp.where(q <= threshold, f_standard, f_boundary)

    def null_pval(self, result: TestStatResult) -> Array | None:
        r"""Null p-value (:math:`\mu' = \mu`), where :math:`q_A = \mu^2/\sigma^2`.

        .. math::

            p_{\mu'=\mu} = \begin{cases}
                1 - \Phi(\sqrt{\tilde{q}})
                    & \text{if } \tilde{q} \leq q_A \\
                1 - \Phi\!\left(\frac{\tilde{q} + q_A}{2\sqrt{q_A}}\right)
                    & \text{if } \tilde{q} > q_A
            \end{cases}

        Returns None (after warning) when ``result.q_asimov`` is missing.
        """
        if not _require_q_asimov(result, self.__class__.__name__, "Null"):
            return None
        q = result.value
        q_asimov = result.q_asimov
        sqrt_q = jnp.sqrt(jnp.maximum(q, 0.0))
        sqrt_qa = jnp.sqrt(jnp.maximum(q_asimov, 0.0))
        p_standard = 1.0 - _PHI(sqrt_q)
        p_boundary = 1.0 - _PHI((q + q_asimov) / (2.0 * sqrt_qa))
        return jnp.where(q <= q_asimov, p_standard, p_boundary)

    def alt_pval(self, result: TestStatResult) -> Array | None:
        r"""Alt p-value (:math:`\mu' = 0`), where :math:`q_A = \mu^2/\sigma^2`.

        .. math::

            p_{\mu'=0} = \begin{cases}
                1 - \Phi(\sqrt{\tilde{q}} - \sqrt{q_A})
                    & \text{if } \tilde{q} \leq q_A \\
                1 - \Phi\!\left(\frac{\tilde{q} - q_A}{2\sqrt{q_A}}\right)
                    & \text{if } \tilde{q} > q_A
            \end{cases}

        Returns None (after warning) when ``result.q_asimov`` is missing.
        """
        if not _require_q_asimov(result, self.__class__.__name__, "Alternative"):
            return None
        q = result.value
        q_asimov = result.q_asimov
        sqrt_q = jnp.sqrt(jnp.maximum(q, 0.0))
        sqrt_qa = jnp.sqrt(jnp.maximum(q_asimov, 0.0))
        p_standard = 1.0 - _PHI(sqrt_q - sqrt_qa)
        p_boundary = 1.0 - _PHI((q - q_asimov) / (2.0 * sqrt_qa))
        return jnp.where(q <= q_asimov, p_standard, p_boundary)

    def expected_pvalues(self, result: TestStatResult) -> ExpectedBands | None:
        r"""Expected p-values at :math:`\pm N\sigma` fluctuations under background-only.

        At band :math:`N`, :math:`\hat{\mu} = N\sigma`, so the expected test
        statistic is (with :math:`q_A = \mu^2/\sigma^2`):

        .. math::

            \tilde{q}_\text{exp} = \begin{cases}
                \max(0,\; \mu/\sigma - N)^2
                    & \text{if } N \geq 0 \\
                (\mu/\sigma)^2 - 2(\mu/\sigma)\,N
                    & \text{if } N < 0
            \end{cases}

        Args:
            result: Must contain ``q_asimov`` for :math:`\sigma` extraction.

        Returns:
            ExpectedBands with (pnull, palt) at each sigma level,
            or None if q_asimov is missing.
        """
        if not _require_q_asimov(result, self.__class__.__name__, "Expected"):
            return None
        sigma = sigma_from_asimov(result.test, result.q_asimov)
        # Guard: at poi=0, sigma=0 and mu/sigma = 0/0 = NaN; substitute 0, so
        # every expected q~ becomes 0. The denominator is sanitised as well
        # (the "double where" pattern). A bare
        # jnp.where(sigma > 0, mu / sigma, 0.0) still evaluates 0/0 in the
        # discarded branch, which makes gradients NaN and, when run eagerly,
        # trips jax_debug_nans, even though the forward value is masked.
        has_sigma = sigma > 0
        safe_sigma = jnp.where(has_sigma, sigma, 1.0)
        mu_over_sigma = jnp.where(has_sigma, result.test / safe_sigma, 0.0)

        def expected_q_fn(n: float) -> Array:
            standard = jnp.maximum(mu_over_sigma - n, 0.0) ** 2
            boundary = mu_over_sigma**2 - 2.0 * mu_over_sigma * n
            # q~ is piecewise in mu-hat: standard for mu-hat >= 0, boundary
            # for mu-hat < 0. At band N, mu-hat = N*sigma, so mu-hat >= 0 iff N >= 0.
            return jnp.where(n >= 0, standard, boundary)

        return _build_expected_bands(self, result, expected_q_fn)
# =============================================================================
# Empirical Distribution (from toys)
# =============================================================================
class EmpiricalDistribution(Distribution):
    """Base class for distributions built from toy test statistics.

    Holds the raw test statistic arrays produced by toy generation and offers
    the ``from_toys`` factory. To implement a custom p-value method (e.g. KDE
    smoothing or tail extrapolation), subclass it and override
    ``null_pval`` / ``alt_pval``.

    Attributes:
        q_null: Test statistics under the tested hypothesis (poi_test).
        q_alt: Test statistics under the alternative hypothesis (poi_alt),
            or None if poi_alt was not provided to the ToyGenerator.
    """

    # Toy test statistics thrown under the tested hypothesis (poi_test).
    q_null: Array
    # Toy test statistics thrown under the alternative (poi_alt); None when
    # the ToyGenerator was not given poi_alt.
    q_alt: Array | None = None

    @classmethod
    def from_toys(cls, toys: ToyResult) -> EmpiricalDistribution:
        """Build an instance of ``cls`` from a ToyResult.

        Args:
            toys: Raw toy generation output holding q_null and, optionally,
                q_alt.

        Returns:
            An instance of the class this method is called on.
        """
        null_samples, alt_samples = toys.q_null, toys.q_alt
        return cls(q_null=null_samples, q_alt=alt_samples)
class SimpleEmpiricalDistribution(EmpiricalDistribution):
    r"""Empirical p-values obtained by counting toys in the tail.

    :math:`p_\text{null}` is the fraction of :math:`q_\text{null} \geq q_\text{obs}`,
    and :math:`p_\text{alt}` is the fraction of :math:`q_\text{alt} \geq q_\text{obs}`
    (ties count as being in the tail).
    """

    def null_pval(self, result: TestStatResult) -> Array:
        r"""Return the fraction of :math:`q_\text{null} \geq q_\text{obs}` (tested hypothesis)."""
        in_tail = self.q_null >= result.value
        return jnp.mean(in_tail.astype(self.q_null.dtype))

    def alt_pval(self, result: TestStatResult) -> Array | None:
        r"""Return the fraction of :math:`q_\text{alt} \geq q_\text{obs}` (alternative).

        Returns None, after emitting a warning, when no alternative toys
        (``q_alt``) are available.
        """
        if self.q_alt is not None:
            in_tail = self.q_alt >= result.value
            return jnp.mean(in_tail.astype(self.q_alt.dtype))
        warnings.warn(
            "Alternative p-value computation in SimpleEmpiricalDistribution "
            "cannot be performed without q_alt toys. "
            "Generate toys with poi_alt to compute alternative p-values.",
            stacklevel=2,
        )
        return None

    def expected_pvalues(self, result: TestStatResult) -> ExpectedBands:
        r"""Compute expected p-values at the standard sigma bands from toy quantiles.

        For each band ``N`` in ``_BAND_SIGMAS``, the :math:`\Phi(N)` quantile
        of ``q_alt`` is used as a synthetic observed test statistic. Empirical
        p-values are then evaluated at it through ``_build_expected_bands``.

        NOTE(review): with the :math:`\Phi(N)` quantile, ``+N`` maps to a
        *larger* test statistic. That matches ``Q0Asymptotic.expected_pvalues``
        (:math:`q = \max(0, \sqrt{q_A} + N)^2`). It is the opposite of
        ``QMuAsymptotic`` / ``QTildeAsymptotic``, where ``+N`` gives a smaller
        :math:`q` (:math:`\max(0, \mu/\sigma - N)^2`). Confirm the intended band
        orientation when these toys come from an exclusion statistic.

        Args:
            result: Test statistic result used as the template for the
                synthetic results.

        Returns:
            ExpectedBands with empirical p-values at each sigma level.

        Raises:
            ValueError: If q_alt is None (no alternative toys generated).
        """
        if self.q_alt is None:
            msg = "expected_pvalues requires q_alt toys. Generate toys with poi_alt to use this method."
            raise ValueError(msg)
        # Cumulative probabilities Phi(N) for N = -2 ... +2.
        levels = jnp.array([_PHI(sig) for sig in _BAND_SIGMAS])
        band_quantiles = jnp.quantile(self.q_alt, levels)
        # _build_expected_bands passes the sigma value itself, so translate it
        # back to its position in _BAND_SIGMAS to pick the matching quantile.
        slot_of = {sig: slot for slot, sig in enumerate(_BAND_SIGMAS)}
        return _build_expected_bands(self, result, lambda sig: band_quantiles[slot_of[sig]])