Source code for everwillow._src.parameters.bounds
"""
Example
-------
>>> import jax.numpy as jnp
>>> from everwillow import statelib as sl
>>> from everwillow.parameters import transforms
>>> state = sl.State.from_pytree({"mu": 0.3})
>>> transform_map = {"mu": transforms.MinuitTransform(lower=0.0, upper=1.0)}
>>> unwrapped = unwrap(state, transform_map)
>>> jnp.isclose(wrap(unwrapped, transform_map)["mu"], state["mu"])
Array(True, dtype=bool)
"""
from __future__ import annotations
__all__ = ["unwrap", "wrap"]
import typing as tp
from everwillow._src.parameters.transforms import TransformBase
from everwillow._src.statelib import K, State, V
[docs]
def unwrap(
    state: State[V],
    transform_mapping: tp.Mapping[K, TransformBase],
) -> State[V]:
    """Transform parameter values from bounded to unconstrained space.

    Applies each transform's ``unwrap`` method to the corresponding
    parameter in ``state``. Parameters not in ``transform_mapping``
    are left unchanged.

    Args:
        state: Parameter state with bounded values.
        transform_mapping: Maps parameter keys to their transforms.

    Returns:
        State with unconstrained (internal) parameter values. A new
        ``State`` is built when ``transform_mapping`` is non-empty;
        otherwise ``state`` itself is returned unchanged.

    Raises:
        KeyError: If ``transform_mapping`` contains keys not in ``state``.
            The offending keys are listed in ``transform_mapping`` order.
    """
    if not transform_mapping:
        # Nothing to transform: skip copying the mapping entirely.
        return state

    # Gather unknown keys as an ordered list rather than a set so the
    # error message is reproducible across runs (set iteration order of
    # str keys depends on hash randomization). Direct membership tests
    # also avoid materializing a set of every key in the state.
    missing = [key for key in transform_mapping if key not in state.mapping]
    if missing:
        msg = f"Transform mapping contains keys not in state: {missing}"
        raise KeyError(msg)

    # Work on a copy so the input state's mapping is never mutated;
    # reassigning existing keys keeps their original insertion order.
    new_mapping = dict(state.mapping)
    for key, transform in transform_mapping.items():
        new_mapping[key] = transform.unwrap(new_mapping[key])
    return State(new_mapping, treedefmeta=state.treedefmeta)
[docs]
def wrap(
    state: State[V],
    transform_mapping: tp.Mapping[K, TransformBase],
) -> State[V]:
    """Transform parameter values from unconstrained back to bounded space.

    Applies each transform's ``wrap`` method to the corresponding
    parameter in ``state``. Parameters not in ``transform_mapping``
    are left unchanged.

    Args:
        state: Parameter state with unconstrained (internal) values.
        transform_mapping: Maps parameter keys to their transforms.

    Returns:
        State with bounded (external) parameter values. A new ``State``
        is built when ``transform_mapping`` is non-empty; otherwise
        ``state`` itself is returned unchanged.

    Raises:
        KeyError: If ``transform_mapping`` contains keys not in ``state``.
            The offending keys are listed in ``transform_mapping`` order.
    """
    if not transform_mapping:
        # Nothing to transform: skip copying the mapping entirely.
        return state

    # Gather unknown keys as an ordered list rather than a set so the
    # error message is reproducible across runs (set iteration order of
    # str keys depends on hash randomization). Direct membership tests
    # also avoid materializing a set of every key in the state.
    missing = [key for key in transform_mapping if key not in state.mapping]
    if missing:
        msg = f"Transform mapping contains keys not in state: {missing}"
        raise KeyError(msg)

    # Work on a copy so the input state's mapping is never mutated;
    # reassigning existing keys keeps their original insertion order.
    new_mapping = dict(state.mapping)
    for key, transform in transform_mapping.items():
        new_mapping[key] = transform.wrap(new_mapping[key])
    return State(new_mapping, treedefmeta=state.treedefmeta)