Transforms
Parameter-space transforms between bounded parameter values and an unconstrained internal representation.
- class everwillow.parameters.transforms.TransformBase
Bases: ABC
Abstract base for parameter transformations.
Subclasses implement unwrap (bounded → unconstrained) and wrap (unconstrained → bounded) using JAX-compatible array operations.
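A minimal plain-Python sketch of the contract described above, assuming scalar float inputs. TransformSketch and ShiftTransform are hypothetical names, and the real TransformBase may declare its methods differently. The sketch only shows that a subclass must supply both directions, and that wrap is expected to invert unwrap. The shift is a toy chosen because its round trip is exact; it does not impose any bound.

```python
from abc import ABC, abstractmethod


class TransformSketch(ABC):
    """Hypothetical stand-in for TransformBase: two abstract directions."""

    @abstractmethod
    def unwrap(self, value: float) -> float:
        """Map a constrained (bounded) value to unconstrained space."""

    @abstractmethod
    def wrap(self, value: float) -> float:
        """Map an unconstrained value back to the constrained space."""


class ShiftTransform(TransformSketch):
    """Hypothetical toy subclass: a constant shift, so wrap(unwrap(v)) == v."""

    def __init__(self, offset: float) -> None:
        self.offset = offset

    def unwrap(self, value: float) -> float:
        return value - self.offset

    def wrap(self, value: float) -> float:
        return value + self.offset


shift = ShiftTransform(offset=2.5)
round_trip = shift.wrap(shift.unwrap(1.0))  # 1.0 - 2.5 + 2.5

try:
    TransformSketch()  # fails: the abstract methods are not implemented
    base_instantiable = True
except TypeError:
    base_instantiable = False
```

Declaring both methods with abstractmethod makes Python refuse to instantiate a subclass that forgets either direction.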
- class everwillow.parameters.transforms.MinuitTransform(lower, upper)
Bases: TransformBase
Minuit-style transform for parameters with finite lower and upper bounds.
unwrap converts a bounded value into an unconstrained internal representation, while wrap inverts the mapping. Both bounds must be finite, with lower < upper.
Example
>>> import jax.numpy as jnp
>>> from everwillow.parameters.transforms import MinuitTransform
>>> transform = MinuitTransform(lower=0.0, upper=1.0)
>>> jnp.isclose(transform.wrap(transform.unwrap(0.3)), 0.3)
Array(True, dtype=bool)
Reference: https://root.cern.ch/download/minuit.pdf (Sec. 1.2.1).
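The mapping can be sketched in plain Python, assuming it uses the sine transformation from the referenced section of the Minuit manual. In that scheme, unwrap computes arcsin(2 (value - lower) / (upper - lower) - 1) and wrap computes lower + (upper - lower) / 2 · (sin(x) + 1). The function names are hypothetical, and the library's JAX implementation may differ in details such as how it treats values exactly at a bound.

```python
import math


def minuit_unwrap(value: float, lower: float, upper: float) -> float:
    """Bounded -> unconstrained via the Minuit arcsine mapping (sketch)."""
    if not (math.isfinite(lower) and math.isfinite(upper) and lower < upper):
        raise ValueError("need finite bounds with lower < upper")
    # Rescale [lower, upper] onto [-1, 1], the domain of arcsin.
    scaled = 2.0 * (value - lower) / (upper - lower) - 1.0
    if not -1.0 <= scaled <= 1.0:
        raise ValueError("value must lie within [lower, upper]")
    return math.asin(scaled)


def minuit_wrap(internal: float, lower: float, upper: float) -> float:
    """Unconstrained -> bounded: sin() lands in [-1, 1], rescaled to [lower, upper]."""
    return lower + 0.5 * (upper - lower) * (math.sin(internal) + 1.0)


internal = minuit_unwrap(0.3, lower=0.0, upper=1.0)
recovered = minuit_wrap(internal, lower=0.0, upper=1.0)
# Any real internal value, however large, maps back inside the bounds.
far = minuit_wrap(100.0, lower=0.0, upper=1.0)
```

Because sin is periodic, wrap is many-to-one: unwrap returns the representative in [-π/2, π/2], but an optimizer working in internal space may move to a different period and still land on the same bounded value.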
- class everwillow.parameters.transforms.SigmoidTransform(lower, upper)
Bases: TransformBase
Logit/sigmoid pair for parameters with finite lower and upper bounds.
unwrap rescales the value affinely onto (0, 1) and applies the logit; wrap applies the sigmoid and undoes the scaling. Both bounds must be finite, with lower < upper.
Example
>>> import jax.numpy as jnp
>>> from everwillow.parameters.transforms import SigmoidTransform
>>> transform = SigmoidTransform(lower=-2.0, upper=3.0)
>>> value = -1.1
>>> jnp.isclose(transform.wrap(transform.unwrap(value)), value)
Array(True, dtype=bool)
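A plain-Python sketch of the pair described above. It assumes the affine scaling is p = (value - lower) / (upper - lower), which maps the interval onto (0, 1). The function names are hypothetical, and the library performs the same steps with JAX array operations rather than the math module.

```python
import math


def _stable_sigmoid(x: float) -> float:
    """Sigmoid 1 / (1 + exp(-x)), evaluated without overflow for large |x|."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def sigmoid_unwrap(value: float, lower: float, upper: float) -> float:
    """Bounded -> unconstrained: affine-scale to (0, 1), then apply the logit."""
    if not (math.isfinite(lower) and math.isfinite(upper) and lower < upper):
        raise ValueError("need finite bounds with lower < upper")
    p = (value - lower) / (upper - lower)
    if not 0.0 < p < 1.0:
        raise ValueError("value must lie strictly between lower and upper")
    return math.log(p) - math.log1p(-p)  # logit(p) = log(p / (1 - p))


def sigmoid_wrap(internal: float, lower: float, upper: float) -> float:
    """Unconstrained -> bounded: sigmoid into (0, 1), then undo the scaling."""
    return lower + (upper - lower) * _stable_sigmoid(internal)


internal = sigmoid_unwrap(-1.1, lower=-2.0, upper=3.0)
recovered = sigmoid_wrap(internal, lower=-2.0, upper=3.0)
midpoint = sigmoid_unwrap(0.5, lower=-2.0, upper=3.0)  # centre of (-2, 3) maps to 0
```

The logit/sigmoid pair is strictly monotone, so each bounded value has exactly one internal representative. In floating point, however, wrap saturates: for large |internal| it rounds to exactly lower or upper.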
- class everwillow.parameters.transforms.OneSidedLogTransform(bound, direction)
Bases: TransformBase
Log transform for parameters with exactly one finite bound.
Direction "lower" enforces value > bound (via log(value - bound)), while direction "upper" enforces value < bound (via log(bound - value)).
Example
>>> import jax.numpy as jnp
>>> from everwillow.parameters.transforms import OneSidedLogTransform
>>> transform = OneSidedLogTransform(bound=0.0, direction="lower")
>>> jnp.isclose(transform.wrap(transform.unwrap(2.0)), 2.0)
Array(True, dtype=bool)
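A plain-Python sketch of both directions. unwrap follows the formulas stated above. wrap is assumed to be the matching exponential inverse, adding exp(x) to the bound for "lower" and subtracting it for "upper". The function names are hypothetical.

```python
import math


def one_sided_log_unwrap(value: float, bound: float, direction: str) -> float:
    """Bounded side -> unconstrained: log of the (positive) distance to the bound."""
    if direction == "lower":
        gap = value - bound  # positive exactly when value > bound
    elif direction == "upper":
        gap = bound - value  # positive exactly when value < bound
    else:
        raise ValueError('direction must be "lower" or "upper"')
    if gap <= 0.0:
        raise ValueError("value violates the bound")
    return math.log(gap)


def one_sided_log_wrap(internal: float, bound: float, direction: str) -> float:
    """Unconstrained -> bounded: exp() yields a positive gap on the allowed side."""
    gap = math.exp(internal)
    if direction == "lower":
        return bound + gap
    if direction == "upper":
        return bound - gap
    raise ValueError('direction must be "lower" or "upper"')


lower_rt = one_sided_log_wrap(one_sided_log_unwrap(2.0, 0.0, "lower"), 0.0, "lower")
upper_rt = one_sided_log_wrap(one_sided_log_unwrap(1.0, 5.0, "upper"), 5.0, "upper")
near_bound = one_sided_log_wrap(-30.0, 0.0, "lower")  # tiny but still above 0.0
```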
- class everwillow.parameters.transforms.SoftPlusTransform
Bases: TransformBase
Softplus transform for strictly positive parameters. wrap maps real space (R) onto positive space (R+), which enforces positivity without requiring a lower or upper bound.
unwrap computes the inverse softplus (with validation), while wrap applies jax.nn.softplus.
Example
>>> import jax.numpy as jnp
>>> from everwillow.parameters.transforms import SoftPlusTransform
>>> transform = SoftPlusTransform()
>>> jnp.isclose(transform.wrap(transform.unwrap(0.8)), 0.8)
Array(True, dtype=bool)
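A plain-Python sketch of the pair that uses the math module in place of jax.nn.softplus, so it runs without JAX. The inverse is written as y + log(-expm1(-y)), which is algebraically equal to log(exp(y) - 1) but avoids overflow for large y. Whether the library uses this exact form is an assumption. The choice of ValueError for the validation step is also an assumption.

```python
import math


def softplus(x: float) -> float:
    """wrap direction, R -> (0, inf): log(1 + exp(x)) as max(x, 0) + log1p(exp(-|x|))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def inverse_softplus(y: float) -> float:
    """unwrap direction, (0, inf) -> R: log(exp(y) - 1), with a positivity check."""
    if not y > 0.0:
        raise ValueError("inverse softplus needs a strictly positive value")
    # log(exp(y) - 1) = y + log(1 - exp(-y)) = y + log(-expm1(-y))
    return y + math.log(-math.expm1(-y))


recovered = softplus(inverse_softplus(0.8))
large = inverse_softplus(1000.0)  # exp(1000) would overflow; this form does not
small_positive = softplus(-5.0)   # negative inputs still map to a positive value
```

For large y the inverse softplus is essentially the identity, which is why the rewritten form returns y unchanged once exp(-y) underflows.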