# Source code for everwillow._src.parameters.transforms

"""
Abstract base and implementations of parameter space transforms.
"""

from __future__ import annotations

import abc
import dataclasses

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike

from everwillow._src.util import float_array

# Public names exported by ``from ... import *``: the abstract base class
# followed by every concrete transform, kept in alphabetical order.
__all__ = [
    "MinuitTransform",
    "OneSidedLogTransform",
    "SigmoidTransform",
    "SoftPlusTransform",
    "TransformBase",
]


def _logit(x: ArrayLike) -> ArrayLike:
    """Return the log-odds ``log(x / (1 - x))`` of ``x``.

    The quotient is split into ``log(x) - log(1 - x)`` and the second term is
    evaluated with ``log1p`` so that it stays accurate when ``x`` is small.
    """
    log_numerator = jnp.log(x)
    log_denominator = jnp.log1p(-x)
    return log_numerator - log_denominator


def _sigmoid(x: ArrayLike) -> ArrayLike:
    """Compute ``1 / (1 + exp(-x))`` without overflow in values or gradients.

    The piecewise formula evaluates ``1 / (1 + exp(-x))`` only on entries
    with ``x >= 0`` and ``exp(x) / (1 + exp(x))`` only on entries with
    ``x < 0``. The remaining entries of each branch are replaced by a
    harmless ``0`` before ``exp`` is applied.

    This masking matters for gradients. ``jnp.where`` evaluates both
    branches, and for large ``|x|`` the discarded branch produces ``nan``
    (``inf / inf``) or a ``nan`` derivative. The forward value is unaffected,
    but reverse-mode differentiation multiplies that ``nan`` by a zero
    cotangent and still returns ``nan``.
    """
    x = jnp.asarray(x)
    nonneg = x >= 0
    # Each branch only sees inputs for which its exp() cannot overflow.
    x_nonneg = jnp.where(nonneg, x, 0.0)
    x_neg = jnp.where(nonneg, 0.0, x)
    exp_neg = jnp.exp(x_neg)
    return jnp.where(
        nonneg,
        1.0 / (1.0 + jnp.exp(-x_nonneg)),
        exp_neg / (1.0 + exp_neg),
    )


@dataclasses.dataclass(frozen=True)
class TransformBase(abc.ABC):
    """
    Interface shared by every parameter-space transform.

    A concrete transform supplies a pair of mutually inverse maps written
    with JAX-compatible array operations:

    * ``wrap``   -- unconstrained real line -> constrained (bounded) space;
    * ``unwrap`` -- constrained (bounded) space -> unconstrained real line.
    """

    @abc.abstractmethod
    def wrap(self, value: ArrayLike) -> ArrayLike:
        """Map ``value`` from the real line back into its constrained space."""

    @abc.abstractmethod
    def unwrap(self, value: ArrayLike) -> ArrayLike:
        """Map ``value`` from its constrained space onto the real line."""
@dataclasses.dataclass(frozen=True)
class MinuitTransform(TransformBase):
    """
    Minuit-style transform for parameters with finite lower and upper bounds.

    ``unwrap`` converts a bounded ("external") value into an unconstrained
    ("internal") representation, and ``wrap`` inverts the mapping::

        internal = arcsin(2 * (external - lower) / (upper - lower) - 1)
        external = lower + (upper - lower) / 2 * (sin(internal) + 1)

    Both bounds must be finite and ``lower < upper``. Bounds may be scalars
    or arrays; array-valued bounds are validated element-wise.

    Example:
        >>> transform = MinuitTransform(lower=0.0, upper=1.0)
        >>> jnp.isclose(transform.wrap(transform.unwrap(0.3)), 0.3)
        Array(True, dtype=bool)

    Reference:
        https://root.cern.ch/download/minuit.pdf (Sec. 1.2.1).
    """

    lower: ArrayLike
    upper: ArrayLike

    def __post_init__(self):
        # Reduce with all()/any() so array-valued bounds are checked
        # element-wise. A bare ``if`` on a multi-element array raises an
        # "ambiguous truth value" error instead.
        if not jnp.all(jnp.isfinite(self.lower)):
            msg = "lower bound must be finite."
            raise ValueError(msg)
        if not jnp.all(jnp.isfinite(self.upper)):
            msg = "upper bound must be finite."
            raise ValueError(msg)
        if jnp.any(jnp.greater_equal(self.lower, self.upper)):
            msg = f"{self} requires lower bound to be strictly less than upper bound."
            raise ValueError(msg)

    def unwrap(self, value: ArrayLike) -> ArrayLike:
        """
        Convert a bounded value into an unconstrained representation.

        Args:
            value: External value; must be finite and lie strictly inside
                ``(lower, upper)``.

        Returns:
            ``arcsin(2 * (value - lower) / (upper - lower) - 1)``, a value in
            ``[-pi/2, pi/2]``.

        Raises:
            A runtime error raised through ``eqx.error_if`` if ``value`` is not
            finite or lies at or outside either bound.
        """
        value, lower, upper = map(float_array, (value, self.lower, self.upper))
        value = eqx.error_if(value, ~jnp.isfinite(value), "value must be finite.")
        error_msg = f"value passed to {self} is exactly at or outside the boundaries [{self.lower}, {self.upper}]."
        # The bounds themselves are rejected too: arcsin's derivative is
        # infinite at +/-1.
        value = eqx.error_if(value, value <= lower, error_msg)
        value = eqx.error_if(value, value >= upper, error_msg)
        # this formula turns user-provided "external" parameter values into "internal" values
        return jnp.arcsin(2.0 * (value - lower) / (upper - lower) - 1.0)

    def wrap(self, value: ArrayLike) -> ArrayLike:
        """
        Convert an unconstrained value back into the bounded interval.

        Args:
            value: Internal (unconstrained) value.

        Returns:
            ``lower + (upper - lower) / 2 * (sin(value) + 1)``, which lies in
            ``[lower, upper]`` because ``sin`` is bounded by ``[-1, 1]``.
        """
        value, lower, upper = map(float_array, (value, self.lower, self.upper))
        return lower + (upper - lower) / 2 * (jnp.sin(value) + 1)
@dataclasses.dataclass(frozen=True)
class SigmoidTransform(TransformBase):
    """
    Logit/sigmoid pair for parameters with finite lower and upper bounds.

    ``unwrap`` applies the logit to the affinely rescaled value
    ``(value - lower) / (upper - lower)``. ``wrap`` applies the sigmoid and
    rescales the result back to ``[lower, upper]``.

    Bounds must be finite with ``lower < upper``. They may be scalars or
    arrays; array-valued bounds are validated element-wise.

    Example:
        >>> transform = SigmoidTransform(lower=-2.0, upper=3.0)
        >>> value = -1.1
        >>> jnp.isclose(transform.wrap(transform.unwrap(value)), value)
        Array(True, dtype=bool)
    """

    lower: ArrayLike
    upper: ArrayLike

    # check for finite boundaries and a non-empty interval
    def __post_init__(self):
        # Reduce with all()/any() so array-valued bounds are checked
        # element-wise. A bare ``if`` on a multi-element array raises an
        # "ambiguous truth value" error instead.
        if not jnp.all(jnp.isfinite(self.lower)):
            msg = "lower bound must be finite."
            raise ValueError(msg)
        if not jnp.all(jnp.isfinite(self.upper)):
            msg = "upper bound must be finite."
            raise ValueError(msg)
        if jnp.any(jnp.greater_equal(self.lower, self.upper)):
            msg = f"{self} requires lower bound to be strictly less than upper bound."
            raise ValueError(msg)

    def unwrap(self, value: ArrayLike) -> ArrayLike:
        """
        Convert a bounded value into an unconstrained representation.

        Args:
            value: External value; must be finite and lie strictly inside
                ``(lower, upper)``.

        Returns:
            ``logit((value - lower) / (upper - lower))``.

        Raises:
            A runtime error raised through ``eqx.error_if`` if ``value`` is not
            finite or lies at or outside either bound.
        """
        value, lower, upper = map(float_array, (value, self.lower, self.upper))
        value = eqx.error_if(value, ~jnp.isfinite(value), "value must be finite.")
        error_msg = f"value passed to {self} is exactly at or outside the boundaries [{self.lower}, {self.upper}]."
        # The bounds map to -inf/+inf under the logit, so they are rejected.
        value = eqx.error_if(value, value <= lower, error_msg)
        value = eqx.error_if(value, value >= upper, error_msg)
        # this formula turns user-provided "external" parameter values into "internal" values
        return _logit((value - lower) / (upper - lower))

    def wrap(self, value: ArrayLike) -> ArrayLike:
        """
        Convert an unconstrained value back into the bounded interval.

        Args:
            value: Internal (unconstrained) value.

        Returns:
            ``lower + (upper - lower) * sigmoid(value)``, which lies in
            ``[lower, upper]``. The sigmoid can round to exactly 0 or 1 for
            large ``|value|``, so the endpoints are reachable.
        """
        value, lower, upper = map(float_array, (value, self.lower, self.upper))
        return lower + (upper - lower) * _sigmoid(value)
@dataclasses.dataclass(frozen=True)
class OneSidedLogTransform(TransformBase):
    """
    Log transform for parameters with exactly one finite bound.

    Direction ``"lower"`` enforces ``value > bound`` (via
    ``log(value - bound)``). Direction ``"upper"`` enforces ``value < bound``
    (via ``log(bound - value)``). The bound may be a scalar or an array; an
    array-valued bound is validated element-wise.

    Example:
        >>> transform = OneSidedLogTransform(bound=0.0, direction="lower")
        >>> jnp.isclose(transform.wrap(transform.unwrap(2.0)), 2.0)
        Array(True, dtype=bool)
    """

    bound: ArrayLike
    direction: str = eqx.field(static=True)  # 'lower' or 'upper'

    # validate the direction and check that the (lower or upper) bound is finite
    def __post_init__(self):
        if self.direction not in ("lower", "upper"):
            message = f"unsupported direction {self.direction!r} for {self}."
            raise ValueError(message)
        # all() so that array-valued bounds are checked element-wise instead
        # of hitting the ambiguous truth value of a multi-element array.
        if not jnp.all(jnp.isfinite(self.bound)):
            msg = "bound must be finite."
            raise ValueError(msg)

    def unwrap(self, value: ArrayLike) -> ArrayLike:
        """
        Convert a single-sided bounded value into an unconstrained representation.

        Args:
            value: External value; must be finite and lie strictly on the
                allowed side of ``bound``.

        Returns:
            ``log(value - bound)`` for ``direction="lower"``, or
            ``log(bound - value)`` for ``direction="upper"``.

        Raises:
            A runtime error raised through ``eqx.error_if`` if ``value`` is not
            finite or does not lie strictly on the allowed side of ``bound``.
        """
        value, bound = map(float_array, (value, self.bound))
        value = eqx.error_if(value, ~jnp.isfinite(value), "value must be finite.")
        if self.direction == "lower":
            error_msg = f"value passed to {self} must be greater than lower bound {self.bound}."
            value = eqx.error_if(value, value <= bound, error_msg)
            return jnp.log(value - bound)
        error_msg = f"value passed to {self} must be less than upper bound {self.bound}."
        value = eqx.error_if(value, value >= bound, error_msg)
        return jnp.log(bound - value)

    def wrap(self, value: ArrayLike) -> ArrayLike:
        """
        Convert an unconstrained value back into the one-sided bounded space.

        Args:
            value: Internal (unconstrained) value.

        Returns:
            ``bound + exp(value)`` for ``direction="lower"``, or
            ``bound - exp(value)`` for ``direction="upper"``.
        """
        value, bound = map(float_array, (value, self.bound))
        if self.direction == "lower":
            return bound + jnp.exp(value)
        return bound - jnp.exp(value)
@dataclasses.dataclass(frozen=True)
class SoftPlusTransform(TransformBase):
    """
    Softplus map between the real line and the positive half-line.

    ``wrap`` sends an unconstrained ``x`` to ``softplus(x) = log(1 + exp(x))``,
    which is strictly positive. ``unwrap`` applies the inverse
    ``log(exp(y) - 1)`` after checking that ``y`` is finite and positive.
    This transform has no bounds to configure.

    Example:
        >>> transform = SoftPlusTransform()
        >>> jnp.isclose(transform.wrap(transform.unwrap(0.8)), 0.8)
        Array(True, dtype=bool)
    """

    def wrap(self, value: ArrayLike) -> ArrayLike:
        """Map a real value to the positive half-line with ``jax.nn.softplus``."""
        return jax.nn.softplus(float_array(value))

    def unwrap(self, value: ArrayLike) -> ArrayLike:
        """
        Invert the softplus map, validating the input first.

        Raises:
            A runtime error raised through ``eqx.error_if`` if ``value`` is not
            finite or is not strictly positive.
        """
        y = float_array(value)
        y = eqx.error_if(y, ~jnp.isfinite(y), "value must be finite.")
        y = eqx.error_if(y, y <= 0, f"expected positive inputs to {self}.")
        # log(exp(y) - 1) == y + log(1 - exp(-y)). expm1 keeps the small-y end
        # accurate, and exp(-y) cannot overflow for positive y.
        return jnp.log(-jnp.expm1(-y)) + y