Parameter Management

Everwillow’s statelib module provides immutable containers for managing model parameters. You wrap parameter dicts in a State to pass them to fit(), uncertainties(), and other APIs. A State flattens nested structures into a single-level mapping with dot-separated string keys, records the original structure so it can be rebuilt later, and provides utilities for updating and partitioning parameters.

Creating and inspecting a State

import everwillow.statelib as sl

# Create from any dict (or nested dict)
state = sl.State.from_pytree({"a": 1.0, "b": {"c": 2.0, "d": 3.0}})

# Flat view with string keys
print(state.to_dict())
# {'a': 1.0, 'b.c': 2.0, 'b.d': 3.0}

# Access by string key
print(state["b.c"])  # 2.0

# Round-trip back to original structure
print(state.to_pytree())
# {'a': 1.0, 'b': {'c': 2.0, 'd': 3.0}}

The default separator is ".". Override it with the sep= argument, or pass sep=None to keep keys as tuples:

state_tuple = sl.State.from_pytree({"a": {"b": 1.0}}, sep=None)
print(list(state_tuple.keys()))  # [('a', 'b')]

state_slash = sl.State.from_pytree({"a": {"b": 1.0}}, sep="/")
print(list(state_slash.keys()))  # ['a/b']
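
To make the flattening concrete, here is a minimal pure-Python sketch of how a nested dict can be turned into separator-joined keys and rebuilt. The helpers flatten and unflatten are hypothetical illustrations and are not Everwillow's implementation. In particular, this sketch re-derives the nesting from the keys themselves, whereas State is described above as tracking the original structure separately.

```python
def flatten(tree, sep=".", _prefix=()):
    """Flatten a nested dict into a single-level dict.

    Keys are joined with `sep`; with sep=None they are kept as tuples.
    Assumes all keys are strings.
    """
    flat = {}
    for key, value in tree.items():
        path = _prefix + (key,)
        if isinstance(value, dict):
            # Recurse, carrying the path walked so far.
            flat.update(flatten(value, sep=sep, _prefix=path))
        else:
            flat[path if sep is None else sep.join(path)] = value
    return flat


def unflatten(flat, sep="."):
    """Rebuild the nested dict from a flat one produced by flatten()."""
    tree = {}
    for key, value in flat.items():
        path = key if sep is None else key.split(sep)
        node = tree
        for part in path[:-1]:
            node = node.setdefault(part, {})
        node[path[-1]] = value
    return tree


nested = {"a": 1.0, "b": {"c": 2.0, "d": 3.0}}
flat = flatten(nested)
print(flat)                       # {'a': 1.0, 'b.c': 2.0, 'b.d': 3.0}
print(flatten(nested, sep=None))  # {('a',): 1.0, ('b', 'c'): 2.0, ('b', 'd'): 3.0}
print(unflatten(flat) == nested)  # True
```

One limitation of rebuilding from keys alone: a key that already contains the separator (for example "b.c" at the top level) cannot be told apart from a nested entry, which is one reason to store the structure explicitly or to choose a different sep.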

Updating parameters

update() returns a new State with the given values replaced; the original is unchanged:

modified = sl.update(state, updates={"a": 99.0})
print(modified["a"])  # 99.0
print(state["a"])  # 1.0 (original unchanged)
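
The copy-on-update pattern can be sketched in plain Python with a read-only mapping. This is only an illustration of the idea: the update function below is a hypothetical stand-in, and rejecting unknown keys is an assumption, not documented Everwillow behavior.

```python
from types import MappingProxyType


def update(params, updates):
    """Return a new read-only mapping with `updates` applied.

    `params` itself is never modified.
    """
    unknown = set(updates) - set(params)
    if unknown:
        # Assumption: only keys that already exist may be updated.
        raise KeyError(f"unknown parameter(s): {sorted(unknown)}")
    return MappingProxyType({**params, **updates})


params = MappingProxyType({"a": 1.0, "b.c": 2.0})
modified = update(params, {"a": 99.0})
print(modified["a"])  # 99.0
print(params["a"])    # 1.0
```

Because neither mapping can be mutated in place, a value handed to one function cannot be changed behind its back by another.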

Partitioning and recombining

partition() splits a State into two States using a predicate: entries for which the predicate is true go to the first, the rest to the second. Excluded entries become None; use .notnone to see only the active ones:

left, right = sl.partition(
    state,
    predicate=lambda key, _: key.startswith("b"),
)

print(left.notnone)
# {'b.c': 2.0, 'b.d': 3.0}

print(right.notnone)
# {'a': 1.0}

# Recombine into the original state
restored = sl.combine_partitions(left, right)
assert restored.to_pytree() == state.to_pytree()
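
The None-placeholder scheme described above can be sketched with plain dicts. The functions partition, notnone, and combine here are hypothetical helpers written for illustration, assuming that both halves keep the full key set; they are not the statelib functions themselves.

```python
def partition(flat, predicate):
    """Split `flat` into two dicts that share the same keys.

    An entry that does not belong to a side is replaced by None there.
    """
    left = {k: (v if predicate(k, v) else None) for k, v in flat.items()}
    right = {k: (None if predicate(k, v) else v) for k, v in flat.items()}
    return left, right


def notnone(flat):
    """Keep only the entries that are not None placeholders."""
    return {k: v for k, v in flat.items() if v is not None}


def combine(left, right):
    """Recombine two partitions by taking the non-None value for each key."""
    return {k: (left[k] if left[k] is not None else right[k]) for k in left}


flat = {"a": 1.0, "b.c": 2.0, "b.d": 3.0}
selected, rest = partition(flat, lambda key, _: key.startswith("b"))
print(selected)       # {'a': None, 'b.c': 2.0, 'b.d': 3.0}
print(notnone(rest))  # {'a': 1.0}
print(combine(selected, rest) == flat)  # True
```

Keeping a placeholder for every key means both halves retain the full key set and order, so recombining is a simple key-by-key choice with no bookkeeping about which entry went where.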

Merging and splitting

merge() combines multiple states into one. split() reverses it:

state_x = sl.State.from_pytree({"x": 1.0})
state_y = sl.State.from_pytree({"y": 2.0, "z": 3.0})

merged = sl.merge(state_x, state_y)
print(merged.to_dict())
# {'x': 1.0, 'y': 2.0, 'z': 3.0}

# to_pytree() returns a tuple, one entry per original state
print(merged.to_pytree())
# ({'x': 1.0}, {'y': 2.0, 'z': 3.0})

# split() recovers individual states
part_x, part_y = sl.split(merged)
print(part_x.to_pytree())  # {'x': 1.0}
print(part_y.to_pytree())  # {'y': 2.0, 'z': 3.0}

When states share overlapping keys (e.g. after renaming), the merged state holds a single value and to_pytree() propagates it to each sub-pytree. See Combining Models for the full pattern.
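
The shared-key behavior can be illustrated with a small pure-Python sketch. The merge and split helpers and the parameter names below are hypothetical, and letting later inputs win on a conflicting value is an assumption made for this sketch, not documented behavior.

```python
def merge(*flats):
    """Merge flat dicts into one and remember which keys each input owned."""
    merged, layout = {}, []
    for flat in flats:
        layout.append(tuple(flat))  # keys of this input, in order
        merged.update(flat)         # assumption: later inputs win on overlap
    return merged, layout


def split(merged, layout):
    """Rebuild one flat dict per original input from the merged values."""
    return tuple({key: merged[key] for key in keys} for keys in layout)


signal = {"mass": 125.0, "sig_yield": 10.0}
background = {"mass": 125.0, "slope": -0.5}

merged, layout = merge(signal, background)
print(merged)  # {'mass': 125.0, 'sig_yield': 10.0, 'slope': -0.5}

# The shared key is stored once, so changing it affects both parts on split.
merged["mass"] = 126.0
sig, bkg = split(merged, layout)
print(sig["mass"], bkg["mass"])  # 126.0 126.0
```

Storing a shared key only once is what ties the two sub-models together: there is a single value to fit, and each sub-structure reads it back under its own key.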

Renaming keys

apply_transformations() renames keys in a state. The original treedef is preserved, so to_pytree() still produces the original structure:

state = sl.State.from_pytree({"m": 125.0, "scale": 1.0})
renamed = sl.apply_transformations(state, {"m": sl.Transform(new_key="mass")})

print(renamed.to_dict())
# {'mass': 125.0, 'scale': 1.0}

# to_pytree() still uses the original key
print(renamed.to_pytree())
# {'m': 125.0, 'scale': 1.0}
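
One way to picture why the original structure survives a rename is to keep a map from each new key back to the key it came from. The sketch below uses plain dicts and hypothetical helpers (rename, to_original); it is an assumption about the general idea, not a description of how Transform or the treedef work internally.

```python
def rename(flat, renames):
    """Rename keys in `flat`.

    Returns the renamed dict plus a map from each new key to its original key.
    """
    out, origin = {}, {}
    for key, value in flat.items():
        new_key = renames.get(key, key)  # unchanged if no rename is given
        out[new_key] = value
        origin[new_key] = key
    return out, origin


def to_original(flat, origin):
    """Express the current values under their original key names."""
    return {origin[key]: value for key, value in flat.items()}


flat = {"m": 125.0, "scale": 1.0}
renamed, origin = rename(flat, {"m": "mass"})
print(renamed)                       # {'mass': 125.0, 'scale': 1.0}
print(to_original(renamed, origin))  # {'m': 125.0, 'scale': 1.0}
```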

Transform also accepts a value_fn to transform the value at the same time:

scaled = sl.apply_transformations(
    state,
    {"m": sl.Transform(new_key="mass_gev", value_fn=lambda _k, v: v * 1000)},
)
print(scaled["mass_gev"])  # 125000.0