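# Source code for everwillow._src.parameters.transforms
"""
Abstract base class and concrete implementations of parameter-space transforms.

Each transform maps a parameter between its constrained domain and the
unconstrained real line: ``unwrap`` goes from constrained to unconstrained,
and ``wrap`` goes back.
"""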
from __future__ import annotations
import abc
import dataclasses
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike
from everwillow._src.util import float_array
# Public names exported by ``from <this module> import *``. The underscore-prefixed
# helpers ``_logit`` and ``_sigmoid`` are private and intentionally not listed.
__all__ = [
    "MinuitTransform",
    "OneSidedLogTransform",
    "SigmoidTransform",
    "SoftPlusTransform",
    "TransformBase",
]
def _logit(x: ArrayLike) -> ArrayLike:
    """Return the log-odds ``log(x / (1 - x))`` of ``x``.

    The two logarithms are taken separately, and ``log1p`` is used for the
    ``1 - x`` term, so no intermediate ratio is ever formed.
    """
    log_p = jnp.log(x)
    log_one_minus_p = jnp.log1p(-x)
    return log_p - log_one_minus_p


def _sigmoid(x: ArrayLike) -> ArrayLike:
    """Compute ``1 / (1 + exp(-x))`` without overflow and with finite gradients.

    Delegates to ``jax.nn.sigmoid`` (``lax.logistic``), which is stable for large
    ``|x|`` both in the forward pass and under autodiff.

    A two-branch ``jnp.where`` formulation is avoided on purpose. ``jnp.where``
    evaluates both branches, and for large ``|x|`` the unselected branch hits
    ``inf / inf``. The resulting ``nan`` then leaks into ``jax.grad`` even
    though the forward value is correct.
    """
    # lax.logistic only accepts inexact (float/complex) dtypes; promote integer
    # or bool input to the default float, the same promotion jnp.exp applies.
    x = jnp.asarray(x, dtype=jnp.result_type(x, float))
    return jax.nn.sigmoid(x)


@dataclasses.dataclass(frozen=True)
class TransformBase(abc.ABC):
    """
    Interface shared by every parameter transform.

    Concrete subclasses provide two mappings that run in opposite directions:

    * ``unwrap``: constrained (bounded) value -> unconstrained real value.
    * ``wrap``: unconstrained real value -> constrained (bounded) value.

    Implementations should rely only on JAX-compatible array operations.
    """

    @abc.abstractmethod
    def wrap(self, value: ArrayLike) -> ArrayLike:
        """Map ``value`` from the real line back into its constrained space."""

    @abc.abstractmethod
    def unwrap(self, value: ArrayLike) -> ArrayLike:
        """Map ``value`` from its constrained space onto the real line."""


@dataclasses.dataclass(frozen=True)
class MinuitTransform(TransformBase):
    """
    Minuit-style transform for parameters with finite lower and upper bounds.

    ``unwrap`` converts a bounded value into an unconstrained internal
    representation via ``arcsin(2 * (value - lower) / (upper - lower) - 1)``.
    ``wrap`` inverts the mapping via
    ``lower + (upper - lower) / 2 * (sin(value) + 1)``.

    Both bounds must be finite and ``lower < upper``. Scalar bounds and
    array-valued bounds (validated element-wise) are both accepted.

    Example:
        >>> transform = MinuitTransform(lower=0.0, upper=1.0)
        >>> jnp.isclose(transform.wrap(transform.unwrap(0.3)), 0.3)
        Array(True, dtype=bool)

    Reference: https://root.cern.ch/download/minuit.pdf (Sec. 1.2.1).
    """

    lower: ArrayLike
    upper: ArrayLike

    def __post_init__(self):
        """Validate the bounds eagerly at construction time.

        Raises:
            ValueError: if any element of ``lower`` or ``upper`` is non-finite,
                or if any element of ``lower`` is not strictly less than the
                corresponding element of ``upper``.
        """
        # Reduce with jnp.all / jnp.any so that array-valued bounds are checked
        # element-wise. bool() on a multi-element array would raise an
        # "ambiguous truth value" error instead of validating.
        if not jnp.all(jnp.isfinite(self.lower)):
            msg = "lower bound must be finite."
            raise ValueError(msg)
        if not jnp.all(jnp.isfinite(self.upper)):
            msg = "upper bound must be finite."
            raise ValueError(msg)
        if jnp.any(jnp.greater_equal(self.lower, self.upper)):
            msg = f"{self} requires lower bound to be strictly less than upper bound."
            raise ValueError(msg)

    def unwrap(self, value: ArrayLike) -> ArrayLike:
        """Convert a bounded value into an unconstrained representation.

        ``eqx.error_if`` raises a runtime error if ``value`` is non-finite, or if
        it lies exactly at or outside ``[lower, upper]``.
        """
        value, lower, upper = map(float_array, (value, self.lower, self.upper))
        value = eqx.error_if(value, ~jnp.isfinite(value), "value must be finite.")
        error_msg = f"value passed to {self} is exactly at or outside the boundaries [{self.lower}, {self.upper}]."
        value = eqx.error_if(value, value <= lower, error_msg)
        value = eqx.error_if(value, value >= upper, error_msg)
        # Turn the user-provided "external" value into Minuit's "internal" value:
        # rescale to (-1, 1), then apply arcsin.
        return jnp.arcsin(2.0 * (value - lower) / (upper - lower) - 1.0)

    def wrap(self, value: ArrayLike) -> ArrayLike:
        """Convert an unconstrained value back into the bounded interval."""
        value, lower, upper = map(float_array, (value, self.lower, self.upper))
        # sin(value) + 1 lies in [0, 2], so the result lies in [lower, upper].
        return lower + (upper - lower) / 2 * (jnp.sin(value) + 1)


@dataclasses.dataclass(frozen=True)
class SigmoidTransform(TransformBase):
    """
    Logit/sigmoid pair for parameters with finite lower and upper bounds.

    ``unwrap`` applies the logit to the affine-scaled value
    ``(value - lower) / (upper - lower)``. ``wrap`` applies the sigmoid and
    rescales the result into the interval between ``lower`` and ``upper``.

    Bounds must be finite with ``lower < upper``. Scalar bounds and
    array-valued bounds (validated element-wise) are both accepted.

    Example:
        >>> transform = SigmoidTransform(lower=-2.0, upper=3.0)
        >>> value = -1.1
        >>> jnp.isclose(transform.wrap(transform.unwrap(value)), value)
        Array(True, dtype=bool)
    """

    lower: ArrayLike
    upper: ArrayLike

    def __post_init__(self):
        """Check for finite, strictly ordered bounds at construction time.

        Raises:
            ValueError: if any element of ``lower`` or ``upper`` is non-finite,
                or if any element of ``lower`` is not strictly less than the
                corresponding element of ``upper``.
        """
        # Reduce with jnp.all / jnp.any so that array-valued bounds are checked
        # element-wise. bool() on a multi-element array would raise an
        # "ambiguous truth value" error instead of validating.
        if not jnp.all(jnp.isfinite(self.lower)):
            msg = "lower bound must be finite."
            raise ValueError(msg)
        if not jnp.all(jnp.isfinite(self.upper)):
            msg = "upper bound must be finite."
            raise ValueError(msg)
        if jnp.any(jnp.greater_equal(self.lower, self.upper)):
            msg = f"{self} requires lower bound to be strictly less than upper bound."
            raise ValueError(msg)

    def unwrap(self, value: ArrayLike) -> ArrayLike:
        """Convert a bounded value into an unconstrained representation.

        ``eqx.error_if`` raises a runtime error if ``value`` is non-finite, or if
        it lies exactly at or outside ``[lower, upper]``.
        """
        value, lower, upper = map(float_array, (value, self.lower, self.upper))
        value = eqx.error_if(value, ~jnp.isfinite(value), "value must be finite.")
        error_msg = f"value passed to {self} is exactly at or outside the boundaries [{self.lower}, {self.upper}]."
        value = eqx.error_if(value, value <= lower, error_msg)
        value = eqx.error_if(value, value >= upper, error_msg)
        # Turn the user-provided "external" value into the "internal" value:
        # rescale to (0, 1), then take the logit.
        return _logit((value - lower) / (upper - lower))

    def wrap(self, value: ArrayLike) -> ArrayLike:
        """Convert an unconstrained value back into the bounded interval."""
        value, lower, upper = map(float_array, (value, self.lower, self.upper))
        return lower + (upper - lower) * _sigmoid(value)


@dataclasses.dataclass(frozen=True)
class OneSidedLogTransform(TransformBase):
    """
    Log transform for parameters with exactly one finite bound.

    Direction ``"lower"`` enforces ``value > bound`` (via ``log(value - bound)``).
    Direction ``"upper"`` enforces ``value < bound`` (via ``log(bound - value)``).
    ``bound`` may be a scalar or an array (validated element-wise).

    Example:
        >>> transform = OneSidedLogTransform(bound=0.0, direction="lower")
        >>> jnp.isclose(transform.wrap(transform.unwrap(2.0)), 2.0)
        Array(True, dtype=bool)
    """

    bound: ArrayLike
    direction: str = eqx.field(static=True)  # 'lower' or 'upper'

    def __post_init__(self):
        """Validate ``direction`` and ``bound`` at construction time.

        Raises:
            ValueError: if ``direction`` is not ``"lower"`` or ``"upper"``, or
                if any element of ``bound`` is non-finite.
        """
        if self.direction not in ("lower", "upper"):
            msg = f"unsupported direction {self.direction!r} for {self}."
            raise ValueError(msg)
        # jnp.all keeps the finiteness check valid for array-valued bounds;
        # bool() on a multi-element array would raise instead of validating.
        if not jnp.all(jnp.isfinite(self.bound)):
            msg = "bound must be finite."
            raise ValueError(msg)

    def unwrap(self, value: ArrayLike) -> ArrayLike:
        """Convert a single-sided bounded value into an unconstrained representation.

        ``eqx.error_if`` raises a runtime error if ``value`` is non-finite or is
        on the wrong side of (or exactly at) ``bound``.
        """
        value, bound = map(float_array, (value, self.bound))
        value = eqx.error_if(value, ~jnp.isfinite(value), "value must be finite.")
        if self.direction == "lower":
            error_msg = f"value passed to {self} must be greater than lower bound {self.bound}."
            value = eqx.error_if(value, value <= bound, error_msg)
            return jnp.log(value - bound)
        # Otherwise direction == "upper" (guaranteed by __post_init__).
        error_msg = f"value passed to {self} must be less than upper bound {self.bound}."
        value = eqx.error_if(value, value >= bound, error_msg)
        return jnp.log(bound - value)

    def wrap(self, value: ArrayLike) -> ArrayLike:
        """Convert an unconstrained value back into the one-sided bounded space."""
        value, bound = map(float_array, (value, self.bound))
        # exp(value) > 0, so the result stays strictly on the allowed side of bound
        # (up to floating-point rounding).
        if self.direction == "lower":
            return bound + jnp.exp(value)
        return bound - jnp.exp(value)


@dataclasses.dataclass(frozen=True)
class SoftPlusTransform(TransformBase):
    """
    Softplus mapping between the real line (R) and the positive reals (R+).

    ``wrap`` sends an unconstrained ``x`` to ``softplus(x) = log(1 + exp(x))``
    using ``jax.nn.softplus``. The result is positive, so no explicit lower or
    upper bound is required. ``unwrap`` is the inverse, ``y -> log(exp(y) - 1)``.
    It rejects non-finite and non-positive inputs.

    Example:
        >>> transform = SoftPlusTransform()
        >>> jnp.isclose(transform.wrap(transform.unwrap(0.8)), 0.8)
        Array(True, dtype=bool)
    """

    def unwrap(self, value: ArrayLike) -> ArrayLike:
        """Invert softplus after checking that ``value`` is finite and positive."""
        y = float_array(value)
        y = eqx.error_if(y, ~jnp.isfinite(y), "value must be finite.")
        y = eqx.error_if(y, y <= 0, f"expected positive inputs to {self}.")
        # log(exp(y) - 1) rewritten as y + log(1 - exp(-y)). exp(y) is never
        # formed, so large y cannot overflow, and expm1 keeps small y accurate.
        return y + jnp.log(-jnp.expm1(-y))

    def wrap(self, value: ArrayLike) -> ArrayLike:
        """Map ``value`` through softplus onto the positive reals."""
        return jax.nn.softplus(float_array(value))