State

Immutable mapping helpers for working with JAX pytrees.

class everwillow.statelib.state.State(mapping, *, treedefmeta)

Bases: BaseMapping[V]

Container that stores flattened pytrees keyed by canonical keys.

The state keeps track of the pytree definition so it can be converted back to the original nested structure.

Examples

>>> state = State.from_pytree({"a": {"b": 2.0}})
>>> state["a.b"]
2.0
>>> state.to_pytree()
{'a': {'b': 2.0}}
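The round trip above can be pictured with a small pure-Python sketch. It is not everwillow's implementation. It assumes the pytree contains only dicts, lists and leaves, uses hypothetical `flatten`/`unflatten` helpers, and replaces the JAX tree definition with a hand-rolled skeleton that records where each leaf belongs.

```python
from typing import Any

# Pure-Python sketch (no JAX, not everwillow code): only dicts, lists and
# leaves are supported. `flatten` and `unflatten` are hypothetical helpers.


def flatten(tree: Any) -> tuple[dict[str, Any], Any]:
    """Return a flat {dotted_key: leaf} mapping and a structure skeleton."""
    flat: dict[str, Any] = {}

    def walk(node: Any, path: tuple) -> Any:
        if isinstance(node, dict):
            # Dict keys are visited in sorted order, as jax.tree_util does.
            return {k: walk(node[k], path + (k,)) for k in sorted(node)}
        if isinstance(node, list):
            return [walk(child, path + (i,)) for i, child in enumerate(node)]
        key = ".".join(str(p) for p in path)
        flat[key] = node
        return key  # the skeleton remembers the key each leaf is stored under

    skeleton = walk(tree, ())
    return flat, skeleton


def unflatten(flat: dict[str, Any], skeleton: Any) -> Any:
    """Rebuild the nested structure, pulling each leaf from `flat`."""
    if isinstance(skeleton, dict):
        return {k: unflatten(flat, v) for k, v in skeleton.items()}
    if isinstance(skeleton, list):
        return [unflatten(flat, v) for v in skeleton]
    return flat[skeleton]
```

Because the skeleton is kept apart from the flat mapping, leaves can be replaced by key (for example `unflatten({**flat, "a.b": 5.0}, skeleton)`) and the original nesting is restored unchanged.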

classmethod from_pytree(pytree, *, is_leaf=None, sep='.')

Build a State instance from an arbitrary pytree.

Parameters:
  • pytree (PyTree[V]) – Nested structure supported by jax.tree_util.

  • is_leaf (Callable[[V], bool] | None) – Optional callable passed to jax.tree_util.tree_flatten_with_path() to customize which nodes are treated as leaves.

  • sep (str | None) – Separator used to join key entries when constructing public keys. Defaults to ".". When None, keys are returned as tuples.

Returns:

New State representing pytree.

Return type:

State[V]

Examples

>>> State.from_pytree({"a": [1, 2]}).mapping
mappingproxy({'a.0': 1, 'a.1': 2})
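A hypothetical `flat_keys` helper illustrates how `sep` and `is_leaf` shape the resulting keys. It is a pure-Python sketch limited to dicts and lists, not the library's code; in particular it treats `None` as a leaf, whereas jax.tree_util treats `None` as an empty subtree.

```python
from typing import Any, Callable, Optional, Union

Key = Union[str, tuple]


def flat_keys(
    tree: Any,
    *,
    is_leaf: Optional[Callable[[Any], bool]] = None,
    sep: Optional[str] = ".",
) -> dict[Key, Any]:
    """Hypothetical helper: flatten dicts/lists into {key: leaf}."""
    out: dict[Key, Any] = {}

    def walk(node: Any, path: tuple) -> None:
        # is_leaf can stop the recursion early, keeping a subtree whole.
        stop = is_leaf is not None and is_leaf(node)
        if not stop and isinstance(node, dict):
            for k in sorted(node):
                walk(node[k], path + (k,))
        elif not stop and isinstance(node, list):
            for i, child in enumerate(node):
                walk(child, path + (i,))
        else:
            # sep=None keeps the raw path tuple; otherwise join the entries.
            key = path if sep is None else sep.join(str(p) for p in path)
            out[key] = node

    walk(tree, ())
    return out
```

With `{"a": [1, 2]}`, the default gives `{"a.0": 1, "a.1": 2}`, `sep=None` gives `{("a", 0): 1, ("a", 1): 2}`, and `is_leaf=lambda n: isinstance(n, list)` keeps the list as a single leaf under `"a"`.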

to_pytree()

Reconstruct the stored pytree using the cached tree definition.

Returns:

Pytree with the same structure used to create the state.

Return type:

PyTree[V]

Examples

>>> state = State.from_pytree({"x": 1})
>>> state.to_pytree()
{'x': 1}

show()

Pretty-print this State with rich array visualization.

tree_flatten_with_keys()

classmethod tree_unflatten(aux_data, children)
Return type:

State[V]

everwillow.statelib.state.canonicalize_key(path: tuple[Any, ...]) -> str
everwillow.statelib.state.canonicalize_key(path: tuple[Any, ...], *, sep: str) -> str
everwillow.statelib.state.canonicalize_key(path: tuple[Any, ...], *, sep: None) -> tuple[str | int, ...]

Convert a JAX key path into a canonical key made of plain Python entries: a string joined with sep, or a tuple when sep is None.

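Parameters:
  • path (tuple[Any, ...]) – JAX key path, such as those produced by jax.tree_util.tree_flatten_with_path().

  • sep (str | None) – Separator used to join the entries. Defaults to "."; when None, the entries are returned as a tuple.

Returns: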

Canonical key representation that can be used to index a State.

Raises:
  • TypeError – If path contains an unsupported key type.

  • ValueError – If a key segment contains the separator string.

Return type:

str | tuple[str | int, …]

Examples

>>> import jax.tree_util as jtu
>>> canonicalize_key((jtu.DictKey("a"), jtu.SequenceKey(0)))
'a.0'

Use sep=None for tuple keys:

>>> canonicalize_key((jtu.DictKey("a"), jtu.SequenceKey(0)), sep=None)
('a', 0)
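To make the documented rules concrete, here is a sketch of the conversion. It assumes only the two key types used in the examples above: `DictKey` and `SequenceKey` are local stand-ins mirroring the `.key` and `.idx` attributes of the jax.tree_util classes so the sketch runs without JAX, and `canonicalize_key_sketch` is a hypothetical name. The real function may accept further key types.

```python
from dataclasses import dataclass
from typing import Any, Optional, Union


# Stand-ins for jax.tree_util.DictKey / SequenceKey (same attribute names).
@dataclass(frozen=True)
class DictKey:
    key: Any


@dataclass(frozen=True)
class SequenceKey:
    idx: int


def canonicalize_key_sketch(
    path: tuple[Any, ...], *, sep: Optional[str] = "."
) -> Union[str, tuple[Union[str, int], ...]]:
    """Hypothetical re-implementation of the documented behaviour."""
    parts: list[Union[str, int]] = []
    for entry in path:
        if isinstance(entry, DictKey):
            parts.append(str(entry.key))
        elif isinstance(entry, SequenceKey):
            parts.append(entry.idx)
        else:
            raise TypeError(f"unsupported key type: {type(entry).__name__}")
    if sep is None:
        return tuple(parts)
    # A separator inside a segment would make the joined key ambiguous.
    for part in parts:
        if isinstance(part, str) and sep in part:
            raise ValueError(f"key segment {part!r} contains separator {sep!r}")
    return sep.join(str(p) for p in parts)
```

The ValueError check is what keeps string keys reversible: `("a.b",)` and `("a", "b")` would otherwise both become `"a.b"`.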
everwillow.statelib.state.combine_partitions(left, right)

Merge two partitions that originated from the same state.

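Parameters:
  • left (State[V]) – First partition, as returned by partition().

  • right (State[V]) – Second partition, as returned by partition().

Returns: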

State containing the union of both partitions.

Raises:

ValueError – If the partitions do not share the same treedefmeta.

Return type:

State[V]

Examples

>>> state = State.from_pytree({"a": 1, "b": 2})
>>> left, right = partition(state, predicate=lambda key, _: key == "a")
>>> combine_partitions(left, right)["b"]
2
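Assuming each partition keeps every key and marks the entries owned by the other side with None (as partition() documents), combining them amounts to picking the non-None value per key. The sketch below works on plain dicts; `combine_partitions_sketch` is hypothetical, and comparing key sets is only a stand-in for the treedefmeta check.

```python
from typing import Any


def combine_partitions_sketch(
    left: dict[str, Any], right: dict[str, Any]
) -> dict[str, Any]:
    # Stand-in for the treedefmeta check: both sides must cover the same keys.
    if left.keys() != right.keys():
        raise ValueError("partitions do not come from the same state")
    # Each key is owned by exactly one side; the other side holds None.
    return {k: left[k] if left[k] is not None else right[k] for k in left}
```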
everwillow.statelib.state.merge(*states)

Combine several States into one.

When states share overlapping keys, the last value wins.

Parameters:

*states (State[V]) – Sequence of State instances to merge (at least one).

Returns:

New State containing all key/value pairs from the inputs.

Raises:

ValueError – If no states are provided.

Return type:

State[V]
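The last-value-wins rule corresponds to applying successive dict updates in argument order. A minimal sketch on plain dicts (hypothetical `merge_sketch`, not the library's code):

```python
from typing import Any, Mapping


def merge_sketch(*states: Mapping[str, Any]) -> dict[str, Any]:
    if not states:
        raise ValueError("merge() requires at least one state")
    merged: dict[str, Any] = {}
    for state in states:
        merged.update(state)  # later inputs overwrite earlier ones
    return merged
```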

everwillow.statelib.state.partition(state, *, predicate)

Split a state into two partitions based on a predicate.

Parameters:
  • state (State[V]) – State instance to partition.

  • predicate (Callable[[str | tuple[str | int, ...], V], bool]) – Callable returning True for items that should go into the first partition.

Returns:

Tuple (left, right) of two State instances partitioned from the original state. Entries that do not satisfy the predicate are set to None in left; entries that do satisfy it are set to None in right.

Return type:

tuple[State[V], State[V]]

Examples

>>> state = State.from_pytree({"a": 1, "b": 2})
>>> left, right = partition(state, predicate=lambda key, _: key == "a")
>>> right.notnone
{'b': 2}
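A plain-dict sketch of the documented split, using a hypothetical `partition_sketch` and assuming keys are already canonical strings. Both halves keep the full key set, with None as the placeholder for entries owned by the other half:

```python
from typing import Any, Callable


def partition_sketch(
    state: dict[str, Any], *, predicate: Callable[[str, Any], bool]
) -> tuple[dict[str, Any], dict[str, Any]]:
    left: dict[str, Any] = {}
    right: dict[str, Any] = {}
    for key, value in state.items():
        goes_left = predicate(key, value)
        left[key] = value if goes_left else None
        right[key] = None if goes_left else value
    return left, right
```

Keeping every key on both sides is what lets the two halves be recombined into the original layout.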
everwillow.statelib.state.split(state)

Split a merged State back into the original States.

For keys that appeared in more than one input, every returned State receives the merged value, that is, the value from the last input that defined the key.

Parameters:

state (State[V]) – State instance created by merge().

Returns:

Tuple of State instances corresponding to the original inputs used to create state.

Raises:

ValueError – If state was not produced by merge().

Return type:

tuple[State[V], …]
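split() can only undo a merge if the merged state remembers which keys came from which input. The sketch below assumes that bookkeeping takes the form of per-input key tuples; `merge_with_segments` and `split_sketch` are hypothetical and stand in for however everwillow actually records this.

```python
from typing import Any, Mapping, Optional

Segments = tuple[tuple[str, ...], ...]


def merge_with_segments(*states: Mapping[str, Any]) -> tuple[dict[str, Any], Segments]:
    """Hypothetical merge that also records each input's key set."""
    if not states:
        raise ValueError("at least one state is required")
    merged: dict[str, Any] = {}
    for state in states:
        merged.update(state)  # last value wins on overlapping keys
    segments = tuple(tuple(state) for state in states)
    return merged, segments


def split_sketch(
    merged: Mapping[str, Any], segments: Optional[Segments]
) -> tuple[dict[str, Any], ...]:
    if segments is None:
        # Mirrors the documented ValueError for states not built by merge().
        raise ValueError("state was not produced by merge()")
    # Every segment reads from the merged mapping, so overlapping keys
    # all see the same (last-written) value.
    return tuple({key: merged[key] for key in keys} for keys in segments)
```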

everwillow.statelib.state.update(state, *, updates)

Return a new state with specific entries replaced.

Parameters:
  • state (State[V]) – Original State to copy.

  • updates (Mapping[str | tuple[str | int, ...], V | EllipsisType]) – Mapping of existing keys to replacement values. Entries whose value is Ellipsis are ignored, which makes it easy to reuse the same dictionary across multiple updates.

Returns:

New State with the replacements applied.

Raises:

KeyError – If updates includes a key that is not present in state.

Return type:

State[V]

Examples

>>> base = State.from_pytree({"a": 1, "b": 2})
>>> update(base, updates={"b": 99}).to_dict()
{'a': 1, 'b': 99}
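A minimal sketch of the documented update rules on plain dicts, using a hypothetical `update_sketch`: unknown keys raise KeyError, an Ellipsis value leaves the entry untouched, and the input is copied rather than modified.

```python
from typing import Any, Mapping


def update_sketch(
    state: Mapping[str, Any], *, updates: Mapping[str, Any]
) -> dict[str, Any]:
    unknown = [key for key in updates if key not in state]
    if unknown:
        raise KeyError(f"keys not present in state: {unknown}")
    new = dict(state)  # copy, so the original mapping is left unchanged
    for key, value in updates.items():
        if value is ...:  # Ellipsis means "keep the current value"
            continue
        new[key] = value
    return new
```

Setting some entries of a shared updates dict to `...` lets one template drive several calls that each replace only a subset of keys.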