Interoperability with existing statistical tools in HEP

Interoperability with existing statistical tools in HEP

This tutorial walks through a simple counting experiment implemented with three JAX-supporting modelling libraries: pyhs3, evermore, and pyhf. To run the examples locally, use uv with the examples dependency group, for instance uv run --group examples python docs/python/test_pyhs3_example.py. Each script:

  1. Builds the likelihood with the library’s native abstractions.

  2. Exposes a pytree of parameters that everwillow can optimise.

  3. Calls everwillow.fit() to obtain the best-fit values.

The counting model

We use a single-bin signal-plus-background measurement with one shared shape modifier and two log-normal normalisation nuisances.

  • Observed events: 37.

  • Signal template: \(s = 3\) events, scaled by the strength parameter \(\mu\).

  • Background templates: \(b_1 = 10\) and \(b_2 = 20\) with shape variations \(b_1^{\mathrm{up}} = 12\), \(b_1^{\mathrm{down}} = 8\) and \(b_2^{\mathrm{up}} = 23\), \(b_2^{\mathrm{down}} = 19\).

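  • Log-normal rate modifiers on the two backgrounds, each with a unit-width Gaussian constraint: \(\theta_{\text{norm1}}\) scales \(b_1\) by 1.1 (up) / 0.9 (down), and \(\theta_{\text{norm2}}\) scales \(b_2\) by 1.05 (up) / 0.95 (down).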

  • A single shape nuisance \(\theta_{\text{shape}}\) shared by both backgrounds.

The total expectation in the Poisson term is

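\[\lambda(\mu, \theta) = \mu s + \kappa_1(\theta_{\text{norm1}})\Bigl(b_1 + \max(0,\,\theta_{\text{shape}})(b_1^{\mathrm{up}}-b_1) + \min(0,\,\theta_{\text{shape}})(b_1-b_1^{\mathrm{down}})\Bigr) + \kappa_2(\theta_{\text{norm2}})\Bigl(b_2 + \max(0,\,\theta_{\text{shape}})(b_2^{\mathrm{up}}-b_2) + \min(0,\,\theta_{\text{shape}})(b_2-b_2^{\mathrm{down}})\Bigr),\]

where the asymmetric log-normal factors interpolate exponentially between the up and down normalisation variations, so that \(\kappa_1(\pm 1) = 1.1,\,0.9\) and \(\kappa_2(\pm 1) = 1.05,\,0.95\):

\[\kappa_1(\theta) = e^{\log(1.1)\max(0,\,\theta) - \log(0.9)\min(0,\,\theta)}, \qquad \kappa_2(\theta) = e^{\log(1.05)\max(0,\,\theta) - \log(0.95)\min(0,\,\theta)}.\]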

The negative log-likelihood combines the Poisson term with standard Gaussian constraints for \(\theta_{\text{norm1}}\), \(\theta_{\text{norm2}}\), and \(\theta_{\text{shape}}\). The tabs below show the runnable examples in each toolkit; run them with uv run --group examples to reproduce the fits.

"""Standalone pyhf counting experiment example."""

from __future__ import annotations

import jax
import jax.numpy as jnp
import pyhf

import everwillow as ew
import everwillow.statelib as sl

# Run JAX in 64-bit precision so the fit is carried out on float64 arrays.
jax.config.update("jax_enable_x64", True)
# Have pyhf build its tensors with the JAX backend so everwillow can work on the likelihood.
pyhf.set_backend("jax")

# Build the workspace
# pyhf JSON specification of the single-bin signal + two-background model.
spec = {
    "channels": [
        {
            "name": "singlebin",
            "samples": [
                {
                    "name": "signal",
                    "data": [3.0],
                    # Free signal-strength factor; declared as the POI in "measurements".
                    "modifiers": [{"name": "mu", "type": "normfactor", "data": None}],
                },
                {
                    "name": "bkg1",
                    "data": [10.0],
                    "modifiers": [
                        # Normalisation variation: x1.1 at theta=+1, x0.9 at theta=-1.
                        {
                            "name": "norm1",
                            "type": "normsys",
                            "data": {"hi": 1.1, "lo": 0.9},
                        },
                        # Shape variation: 12 events at theta=+1, 8 events at theta=-1.
                        {
                            "name": "shape1",
                            "type": "histosys",
                            "data": {"hi_data": [12.0], "lo_data": [8.0]},
                        },
                    ],
                },
                {
                    "name": "bkg2",
                    "data": [20.0],
                    "modifiers": [
                        # Normalisation variation: x1.05 at theta=+1, x0.95 at theta=-1.
                        {
                            "name": "norm2",
                            "type": "normsys",
                            "data": {"hi": 1.05, "lo": 0.95},
                        },
                        # Same name as bkg1's histosys, so both backgrounds share
                        # the single "shape1" nuisance parameter.
                        {
                            "name": "shape1",
                            "type": "histosys",
                            "data": {"hi_data": [23.0], "lo_data": [19.0]},
                        },
                    ],
                },
            ],
        }
    ],
    "observations": [{"name": "singlebin", "data": [37.0]}],
    "measurements": [
        {
            "name": "Measurement",
            "config": {
                "poi": "mu",
                # No per-parameter overrides (bounds, inits, fixed values).
                "parameters": [],
            },
        }
    ],
    "version": "1.0.0",
}

workspace = pyhf.Workspace(spec)
model = workspace.model()

# Get parameter order and initial values
parameter_order = model.config.par_order
# The zip below pairs each parameter name with exactly one entry of the flat
# ``suggested_init()`` list, which is only valid while every parameter is a
# scalar (true for this model). ``strict=True`` turns a violation of that
# assumption -- e.g. adding a multi-bin modifier -- into an immediate error
# instead of a silently truncated, misaligned dict.
initial_dict = dict(zip(parameter_order, model.config.suggested_init(), strict=True))
# Observed main-channel data with the auxiliary (constraint) data appended.
observation = workspace.data(model, include_auxdata=True)


# Define NLL
def nll(params, obs):
    """Return the negative log-likelihood ``-log L`` of the pyhf model.

    Args:
        params: Mapping from every name in ``parameter_order`` to its scalar
            value; this is the pytree that everwillow optimises.
        obs: Main data followed by auxiliary data, as produced by
            ``workspace.data(model, include_auxdata=True)``.

    Returns:
        The scalar ``-log L``, the same convention as the evermore and pyhs3
        objectives. A ``-2 log L`` objective has the same minimum but twice the
        curvature, which would mis-scale any Hessian-based uncertainty.
    """
    # pyhf takes the parameters as one flat vector ordered like ``par_order``.
    parameter_vector = jnp.asarray([params[name] for name in parameter_order])
    # ``model.logpdf`` returns a one-element array for this unbatched model.
    logpdf = model.logpdf(parameter_vector, obs)
    return -logpdf[0]


# Perform the fit, starting from pyhf's suggested initial parameter values.
start_state = sl.State.from_pytree(initial_dict)
result = ew.fit(nll, start_state, observation)

print(result.params)
#  {
#   'mu': Array(2.33333334, dtype=float64),
#   'norm1': Array(-3.37256584e-07, dtype=float64),
#   'norm2': Array(4.53219024e-07, dtype=float64),
#   'shape1': Array(-2.25047663e-08, dtype=float64),
# }
"""Standalone evermore counting experiment example."""

from __future__ import annotations

import typing as tp
from functools import partial

import evermore as evm
import jax
import jax.numpy as jnp
from flax import nnx
from jaxtyping import Array, Float, PyTree

import everwillow as ew
import everwillow.statelib as sl

# Run JAX in 64-bit precision so the fit is carried out on float64 arrays.
jax.config.update("jax_enable_x64", True)


# type defs
# A 1D histogram: a float array with a single ``nbins`` axis.
Hist1D: tp.TypeAlias = Float[Array, " nbins"]
# Extra arguments bound into ``loss``: the graph definition, the non-fitted
# (static) state from ``nnx.split``, and the template histograms.
Args: tp.TypeAlias = tuple[
    nnx.GraphDef,  # graphdef
    nnx.State,  # state
    PyTree[Hist1D],  # hists
]


class Model(nnx.Module):
    """Single-bin model: a scaled signal plus two modified backgrounds.

    ``mu`` scales the signal. ``norm1`` and ``norm2`` apply asymmetric
    log-normal rate modifiers to ``bkg1`` and ``bkg2`` respectively. ``shape``
    morphs both backgrounds between their up/down templates and is shared by
    the two processes.
    """

    def __init__(
        self,
        mu: evm.Parameter,
        norm1: evm.NormalParameter,
        norm2: evm.NormalParameter,
        shape: evm.NormalParameter,
    ):
        self.mu = mu
        self.norm1 = norm1
        self.norm2 = norm2
        self.shape = shape

    def _background(
        self,
        name: str,
        norm: evm.NormalParameter,
        up: float,
        down: float,
        hists: PyTree[Hist1D],
    ) -> Hist1D:
        """Apply ``norm``'s rate modifier, composed with the shared shape morphing, to background ``name``."""
        rate = norm.scale_log_asymmetric(up=jnp.array([up]), down=jnp.array([down]))
        morph = self.shape.morphing(
            up_template=hists["shape_up"][name],
            down_template=hists["shape_down"][name],
        )
        # ``@`` composes the two modifiers into one before applying it to the nominal.
        return (rate @ morph)(hists["nominal"][name])

    def __call__(self, hists: PyTree[Hist1D]) -> PyTree[Hist1D]:
        """Return the modified expectation of every process, keyed by process name."""
        signal = self.mu.scale()(hists["nominal"]["signal"])
        bkg1 = self._background("bkg1", self.norm1, 1.1, 0.9, hists)
        bkg2 = self._background("bkg2", self.norm2, 1.05, 0.95, hists)
        return {"signal": signal, "bkg1": bkg1, "bkg2": bkg2}


# Template histograms: nominal yields per process, plus the shape-variation
# templates consumed by ``Model.shape.morphing`` for each background.
hists = {
    "nominal": {
        "signal": jnp.array([3.0]),
        "bkg1": jnp.array([10.0]),
        "bkg2": jnp.array([20.0]),
    },
    "shape_up": {
        "bkg1": jnp.array([12.0]),
        "bkg2": jnp.array([23.0]),
    },
    "shape_down": {
        "bkg1": jnp.array([8.0]),
        "bkg2": jnp.array([19.0]),
    },
}


# ``mu`` is a plain parameter; the three nuisances are NormalParameter instances.
model = Model(
    mu=evm.Parameter(name="mu"),
    norm1=evm.NormalParameter(name="norm1"),
    norm2=evm.NormalParameter(name="norm2"),
    shape=evm.NormalParameter(name="shape"),
)

observation = jnp.array([37.0])
# NOTE(review): this eager forward pass is not used below (``loss`` recomputes
# the expectations itself); it only checks that the model evaluates.
expectations = model(hists)


@nnx.jit
def loss(dynamic: nnx.State, observation: Hist1D, args: Args) -> Float[Array, ""]:
    """Negative log-likelihood of the model: Poisson term plus parameter constraints.

    ``dynamic`` holds the fitted parameters. ``args`` carries the graph
    definition, the static (non-fitted) state and the template histograms
    needed to rebuild the model around them.
    """
    graphdef, static, templates = args
    # Rebuild the full module from its split pieces.
    fitted = nnx.merge(graphdef, dynamic, static)
    per_process = fitted(templates)
    constraint_log_probs = evm.loss.get_log_probs(fitted)
    # Sum the per-process expectations into the total prediction.
    total_expectation = evm.util.sum_over_leaves(per_process)
    poisson_term = evm.pdf.PoissonContinuous(total_expectation).log_prob(observation).sum()
    constraint_term = evm.util.sum_over_leaves(constraint_log_probs)
    return -jnp.sum(poisson_term + constraint_term)


# Separate the fitted (dynamic) parameters from everything else in the model.
graphdef, dynamic, static = nnx.split(model, evm.filter.is_dynamic_parameter, ...)

# Perform the fit: the static pieces are bound into the loss up front, so each
# call only receives the dynamic parameters and the observation.
objective = partial(loss, args=(graphdef, static, hists))
start_state = sl.State.from_pytree(dynamic)
result = ew.fit(objective, start_state, observation)

print(result.params)
# {
#   'mu': Array(2.33333346, dtype=float64),
#   'norm1': Array(3.0642646e-08, dtype=float64),
#   'norm2': Array(2.33507612e-08, dtype=float64),
#   'shape': Array(-1.66275153e-08, dtype=float64),
# }
"""Standalone pyhs3 counting experiment example."""

from __future__ import annotations

import jax
import jax.numpy as jnp
import pyhs3
from pyhs3.data import PointData
from pyhs3.distributions import GaussianDist, PoissonDist, ProductDist
from pyhs3.functions import GenericFunction, InterpolationFunction
from pyhs3.metadata import Metadata
from pyhs3.parameter_points import ParameterPoint, ParameterSet
from pytensor.compile import mode
from pytensor.graph.basic import graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.link.jax.dispatch import jax_funcify

import everwillow as ew
import everwillow.statelib as sl

# Run JAX in 64-bit precision so the fit is carried out on float64 arrays.
jax.config.update("jax_enable_x64", True)


def jaxify_distribution(model, distribution_name):
    """Compile one named distribution of a pyhs3 model into a JAX callable.

    Args:
        model: pyhs3 model whose ``distributions`` mapping holds the graph.
        distribution_name: Key of the distribution to convert.

    Returns:
        ``(inputs, fn)``: ``inputs`` are the named graph inputs (unnamed ones,
        presumably constants, are dropped), and ``fn`` is the function produced
        by ``jax_funcify``, called with one value per entry of ``inputs``.
    """
    output = model.distributions[distribution_name]
    named_inputs = [node for node in graph_inputs([output]) if node.name is not None]
    # Build the graph on a clone so the rewrites below work on a copy.
    fgraph = FunctionGraph(inputs=named_inputs, outputs=[output], clone=True)
    # Apply the JAX compilation mode's graph rewrites before translation.
    mode.JAX.optimizer.rewrite(fgraph)
    jax_fn = jax_funcify(fgraph)
    return named_inputs, jax_fn


# Build the workspace
# The likelihood is the product of a Poisson main measurement and one
# unit-width Gaussian constraint per nuisance (``a_*`` are auxiliary observations).
workspace = pyhs3.Workspace(
    metadata=Metadata(hs3_version="0.2"),
    distributions=[
        PoissonDist(name="main_poisson", x="n_obs", mean="n_expected"),
        GaussianDist(name="norm1_constraint", x="a_norm1", mean="norm1", sigma=1.0),
        GaussianDist(name="norm2_constraint", x="a_norm2", mean="norm2", sigma=1.0),
        GaussianDist(name="shape1_constraint", x="a_shape1", mean="shape1", sigma=1.0),
        ProductDist(
            type="product_dist",
            name="model",
            factors=[
                "main_poisson",
                "norm1_constraint",
                "norm2_constraint",
                "shape1_constraint",
            ],
        ),
    ],
    functions=[
        GenericFunction(name="signal_expected", expression="mu * signal_nominal"),
        # bkg1 normalisation factor: interpolates from lnN_nom (1.0) towards the
        # 1.1 / 0.9 variations. NOTE(review): interpolation code 1 is presumably
        # the exponential scheme -- confirm against the pyhs3/HS3 docs.
        InterpolationFunction(
            name="bkg1_lnN_factor",
            nom="lnN_nom",
            high=["bkg1_lnN_up"],
            low=["bkg1_lnN_down"],
            vars=["norm1"],
            interpolationCodes=[1],
            positiveDefinite=False,
        ),
        # bkg1 shape morphing between its up/down templates. NOTE(review): code 0
        # is presumably piecewise-linear -- confirm.
        InterpolationFunction(
            name="bkg1_shape_interp",
            nom="bkg1_nominal",
            high=["bkg1_shape_up"],
            low=["bkg1_shape_down"],
            vars=["shape1"],
            interpolationCodes=[0],
            positiveDefinite=False,
        ),
        GenericFunction(name="bkg1_expected", expression="bkg1_lnN_factor * bkg1_shape_interp"),
        # bkg2 normalisation factor towards the 1.05 / 0.95 variations.
        InterpolationFunction(
            name="bkg2_lnN_factor",
            nom="lnN_nom",
            high=["bkg2_lnN_up"],
            low=["bkg2_lnN_down"],
            vars=["norm2"],
            interpolationCodes=[1],
            positiveDefinite=False,
        ),
        # bkg2 shape morphing, driven by the same shape1 nuisance as bkg1.
        InterpolationFunction(
            name="bkg2_shape_interp",
            nom="bkg2_nominal",
            high=["bkg2_shape_up"],
            low=["bkg2_shape_down"],
            vars=["shape1"],
            interpolationCodes=[0],
            positiveDefinite=False,
        ),
        GenericFunction(name="bkg2_expected", expression="bkg2_lnN_factor * bkg2_shape_interp"),
        # Total expected yield fed to the Poisson term.
        GenericFunction(
            name="n_expected",
            expression="signal_expected + bkg1_expected + bkg2_expected",
        ),
    ],
    # Starting values for the fitted parameters.
    parameter_points=[
        ParameterSet(
            name="default_values",
            parameters=[
                ParameterPoint(name="mu", value=1.0),
                ParameterPoint(name="norm1", value=0.0),
                ParameterPoint(name="norm2", value=0.0),
                ParameterPoint(name="shape1", value=0.0),
            ],
        )
    ],
    # Observed values (n_obs and the a_* auxiliaries), followed by the fixed
    # template yields and normalisation variations.
    data=[
        PointData(name="n_obs", value=37.0),
        PointData(name="a_norm1", value=0.0),
        PointData(name="a_norm2", value=0.0),
        PointData(name="a_shape1", value=0.0),
        PointData(name="signal_nominal", value=3.0),
        PointData(name="bkg1_nominal", value=10.0),
        PointData(name="bkg1_shape_up", value=12.0),
        PointData(name="bkg1_shape_down", value=8.0),
        PointData(name="bkg2_nominal", value=20.0),
        PointData(name="bkg2_shape_up", value=23.0),
        PointData(name="bkg2_shape_down", value=19.0),
        PointData(name="lnN_nom", value=1.0),
        PointData(name="bkg1_lnN_up", value=1.1),
        PointData(name="bkg1_lnN_down", value=0.9),
        PointData(name="bkg2_lnN_up", value=1.05),
        PointData(name="bkg2_lnN_down", value=0.95),
    ],
)

# Extract model and convert to JAX
model = workspace.model()
inputs, jaxified = jaxify_distribution(model, "model")

# Build initial parameters and data
initial = {point.name: float(point.value) for point in workspace.parameter_points[0].parameters}
data_values = {point.name: float(point.value) for point in workspace.data}

# Separate observation from templates.
# The observation is the main count plus the auxiliary constraint measurements;
# it is passed to the fit as data.
observation_names = ("n_obs", "a_norm1", "a_norm2", "a_shape1")
observation = {name: data_values[name] for name in observation_names}
# Every other data entry is a fixed template input. Deriving the set instead of
# listing it by hand keeps it in sync with ``workspace.data``, so no entry (such
# as the shape-down or lnN variations) can be left out and trigger a KeyError
# when ``nll`` looks up the graph inputs by name.
templates = {name: value for name, value in data_values.items() if name not in observation}


# Define NLL
def nll(params, obs):
    """Return ``-log p`` of the full product model at the given parameters.

    Args:
        params: Fitted parameter values keyed by name.
        obs: Observed main and auxiliary values keyed by name.
    """
    # On key collisions the later mapping wins: params over obs over templates.
    values = {**templates, **obs, **params}
    # The JAX function takes one positional argument per graph input, in order.
    positional = [values[var.name] for var in inputs]
    outputs = jaxified(*positional)
    density = jnp.asarray(outputs[0])
    return -jnp.log(density)


# Perform the fit, starting from the workspace's default parameter point.
start_state = sl.State.from_pytree(initial)
result = ew.fit(nll, start_state, observation)


print(result.params)
# {
#   'mu': Array(2.33333374, dtype=float64),
#   'norm1': Array(-7.24415294e-09, dtype=float64),
#   'norm2': Array(-1.6095118e-08, dtype=float64),
#   'shape1': Array(-1.96874884e-07, dtype=float64),
# }