State
Immutable mapping helpers for working with JAX pytrees.
- class everwillow.statelib.state.State(mapping, *, treedefmeta)
Bases: BaseMapping[V]
Container that stores flattened pytrees keyed by canonical keys.
The state keeps track of the pytree definition so it can be converted back to the original nested structure.
Examples
>>> state = State.from_pytree({"a": {"b": 2.0}})
>>> state["a.b"]
2.0
>>> state.to_pytree()
{'a': {'b': 2.0}}
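The round trip that State relies on can be sketched directly with jax.tree_util. The idea is to flatten the pytree with key paths, keep the treedef, and later rebuild the nested structure from the flat values. This is a minimal illustration, not the everwillow implementation. The helpers `_segment`, `to_flat`, and `from_flat` are hypothetical, and the sketch assumes the flat dict preserves JAX's flattening order.

```python
import jax.tree_util as jtu


def _segment(entry):
    # Hypothetical helper: DictKey/FlattenedIndexKey expose .key,
    # SequenceKey exposes .idx, GetAttrKey exposes .name.
    for attr in ("key", "idx", "name"):
        if hasattr(entry, attr):
            return str(getattr(entry, attr))
    raise TypeError(f"unsupported key entry: {entry!r}")


def to_flat(pytree, sep="."):
    """Flatten a pytree into {joined_key: leaf} plus the treedef needed to rebuild it."""
    pairs, treedef = jtu.tree_flatten_with_path(pytree)
    flat = {sep.join(_segment(e) for e in path): leaf for path, leaf in pairs}
    return flat, treedef


def from_flat(flat, treedef):
    """Rebuild the nested structure; dict insertion order matches the flattening order."""
    return jtu.tree_unflatten(treedef, list(flat.values()))


flat, treedef = to_flat({"a": {"b": 2.0}, "c": [1, 2]})
print(flat)                      # {'a.b': 2.0, 'c.0': 1, 'c.1': 2}
print(from_flat(flat, treedef))  # {'a': {'b': 2.0}, 'c': [1, 2]}
```

Because `from_flat` feeds the values back in insertion order, the flat dict must not be reordered, and no two paths may collapse onto the same joined key.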
- classmethod from_pytree(pytree, *, is_leaf=None, sep='.')
Build a State instance from an arbitrary pytree.
- Parameters:
pytree (PyTree[~V]) – Nested structure supported by jax.tree_util.
is_leaf (Callable[[V], bool] | None) – Optional callable passed to jax.tree_util.tree_flatten_with_path() to customize which nodes are treated as leaves.
sep (str | None) – Separator used to join key entries when constructing public keys. Defaults to ".". When None, keys are returned as tuples.
- Returns:
New State representing pytree.
- Return type:
State[V]
Examples
>>> State.from_pytree({"a": [1, 2]}).mapping
mappingproxy({'a.0': 1, 'a.1': 2})
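The following sketch shows what `is_leaf` and `sep` change. It builds only the flat key/value view with jax.tree_util.tree_flatten_with_path. `flat_view` is a hypothetical stand-in for the mapping that from_pytree produces. It assumes keys are built from `DictKey.key`, `SequenceKey.idx`, and `GetAttrKey.name`.

```python
import jax.tree_util as jtu


def _raw_entry(entry):
    # Hypothetical helper: unwrap a JAX key object into the value it holds.
    for attr in ("key", "idx", "name"):
        if hasattr(entry, attr):
            return getattr(entry, attr)
    raise TypeError(f"unsupported key entry: {entry!r}")


def flat_view(pytree, *, is_leaf=None, sep="."):
    """Hypothetical stand-in for State.from_pytree(...).mapping."""
    pairs, _ = jtu.tree_flatten_with_path(pytree, is_leaf=is_leaf)
    out = {}
    for path, leaf in pairs:
        parts = tuple(_raw_entry(e) for e in path)
        # sep=None keeps tuple keys; otherwise join the entries as strings.
        out[parts if sep is None else sep.join(map(str, parts))] = leaf
    return out


tree = {"w": (1, 2), "b": [3]}
print(flat_view(tree))                                          # {'b.0': 3, 'w.0': 1, 'w.1': 2}
print(flat_view(tree, is_leaf=lambda n: isinstance(n, tuple)))  # {'b.0': 3, 'w': (1, 2)}
print(flat_view(tree, sep=None))                                # {('b', 0): 3, ('w', 0): 1, ('w', 1): 2}
```

Marking tuples as leaves stops the traversal at `"w"`, so the whole tuple is stored under a single key instead of one key per element.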
- everwillow.statelib.state.canonicalize_key(path: tuple[Any, ...]) → str
- everwillow.statelib.state.canonicalize_key(path: tuple[Any, ...], *, sep: str) → str
- everwillow.statelib.state.canonicalize_key(path: tuple[Any, ...], *, sep: None) → tuple[str | int, ...]
Convert a JAX key path to plain Python entries.
- Parameters:
path (tuple[Any, ...]) – Key path emitted by jax.tree_util.tree_flatten_with_path().
sep (str | None) – Separator used to join the entries into a string. Defaults to ".". When None, the key is returned as a tuple.
- Returns:
Canonical key representation that can be used to index a State.
- Raises:
TypeError – If path contains an unsupported key type.
ValueError – If a key segment contains the separator string.
- Return type:
str | tuple[str | int, ...]
Examples
>>> import jax.tree_util as jtu
>>> canonicalize_key((jtu.DictKey("a"), jtu.SequenceKey(0)))
'a.0'
Use sep=None for tuple keys:
>>> canonicalize_key((jtu.DictKey("a"), jtu.SequenceKey(0)), sep=None)
('a', 0)
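The conversion rules above can be sketched as follows. This is an illustrative re-implementation, not the everwillow source. The function name `canonicalize_key_sketch` is hypothetical. The set of JAX key types treated as supported (DictKey, SequenceKey, GetAttrKey, FlattenedIndexKey) is also an assumption.

```python
import jax.tree_util as jtu


def canonicalize_key_sketch(path, *, sep="."):
    """Hypothetical re-implementation of the documented canonicalize_key behavior."""
    entries = []
    for entry in path:
        # Map each JAX key object to the plain Python value it wraps.
        if isinstance(entry, (jtu.DictKey, jtu.FlattenedIndexKey)):
            entries.append(entry.key)
        elif isinstance(entry, jtu.SequenceKey):
            entries.append(entry.idx)
        elif isinstance(entry, jtu.GetAttrKey):
            entries.append(entry.name)
        else:
            raise TypeError(f"unsupported key entry: {entry!r}")
    if sep is None:
        return tuple(entries)  # tuple keys need no separator check
    for segment in entries:
        # A separator inside a segment would make the joined key ambiguous.
        if sep in str(segment):
            raise ValueError(f"key segment {segment!r} contains separator {sep!r}")
    return sep.join(str(segment) for segment in entries)


print(canonicalize_key_sketch((jtu.DictKey("a"), jtu.SequenceKey(0))))            # a.0
print(canonicalize_key_sketch((jtu.DictKey("a"), jtu.SequenceKey(0)), sep=None))  # ('a', 0)
```

The sketch checks for the separator only when joining. With `sep=None`, any segment is therefore accepted, which matches the documented ValueError being tied to the separator string.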
- everwillow.statelib.state.combine_partitions(left, right)
Merge two partitions that originated from the same state.
- Parameters:
left (State[V]) – First partition returned by partition().
right (State[V]) – Second partition returned by partition().
- Returns:
State containing the union of both partitions.
- Raises:
ValueError – If the partitions do not share the same treedefmeta.
- Return type:
State[V]
Examples
>>> state = State.from_pytree({"a": 1, "b": 2})
>>> left, right = partition(state, predicate=lambda key, _: key == "a")
>>> combine_partitions(left, right)["b"]
2
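partition() leaves None in the slots that belong to the other side. Combining two partitions therefore amounts to taking, for each key, whichever side is not None. The sketch below shows that rule on plain dicts. It is an assumption about the semantics rather than the library code, and `combine_partitions_sketch` is a hypothetical name. Its key-set comparison stands in for the treedefmeta check that the real function performs.

```python
def combine_partitions_sketch(left, right):
    """Hypothetical dict-based version of combine_partitions."""
    # Stand-in for the treedefmeta check: both partitions must cover the same keys.
    if left.keys() != right.keys():
        raise ValueError("partitions do not come from the same state")
    # Each key is filled on exactly one side; prefer the non-None value.
    return {k: left[k] if left[k] is not None else right[k] for k in left}


left = {"a": 1, "b": None}
right = {"a": None, "b": 2}
print(combine_partitions_sketch(left, right))  # {'a': 1, 'b': 2}
```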
- everwillow.statelib.state.merge(*states)
Combine several States into one.
When states share overlapping keys, the last value wins.
- Parameters:
*states (State[V]) – Sequence of State instances to merge (at least one).
- Returns:
New State containing all key/value pairs from the inputs.
- Raises:
ValueError – If no states are provided.
- Return type:
State[V]
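The last-value-wins rule is the same one that `dict.update` applies. Below is a minimal sketch on plain mappings. It assumes only the key/value contents matter; how the real merge reconciles treedefmeta across its inputs is not covered. `merge_sketch` is a hypothetical name.

```python
from types import MappingProxyType


def merge_sketch(*mappings):
    """Hypothetical mapping-only version of merge: later inputs override earlier ones."""
    if not mappings:
        raise ValueError("merge() requires at least one state")
    merged = {}
    for mapping in mappings:
        merged.update(mapping)  # overlapping keys take the value from the later mapping
    # Return a read-only view, mirroring the mappingproxy exposed by State.mapping.
    return MappingProxyType(merged)


print(dict(merge_sketch({"a": 1, "b": 2}, {"b": 3, "c": 4})))  # {'a': 1, 'b': 3, 'c': 4}
```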
- everwillow.statelib.state.partition(state, *, predicate)
Split a state into two partitions based on a predicate.
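- Parameters:
state (State[V]) – State to split.
predicate – Callable invoked as predicate(key, value) for each entry; entries for which it returns True are kept in left.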
- Returns:
Tuple (left, right) containing two State instances partitioned from the original state. Elements not satisfying the predicate are set to None in left, and vice versa for right.
- Return type:
Examples
>>> state = State.from_pytree({"a": 1, "b": 2})
>>> left, right = partition(state, predicate=lambda key, _: key == "a")
>>> right.notnone
{'b': 2}
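The None-hole layout described above can be sketched on plain dicts. Both outputs keep every key, and an entry is None on whichever side it does not belong to. `partition_sketch` and `notnone` are hypothetical helpers, not the everwillow API. `notnone` only imitates the `notnone` view shown in the example.

```python
def partition_sketch(mapping, predicate):
    """Hypothetical dict-based version of partition: same keys on both sides, None as a hole."""
    left, right = {}, {}
    for key, value in mapping.items():
        keep_left = predicate(key, value)  # evaluate once per entry
        left[key] = value if keep_left else None
        right[key] = None if keep_left else value
    return left, right


def notnone(mapping):
    """Drop the None holes, imitating State.notnone."""
    return {k: v for k, v in mapping.items() if v is not None}


left, right = partition_sketch({"a": 1, "b": 2}, lambda key, _: key == "a")
print(left)           # {'a': 1, 'b': None}
print(notnone(right))  # {'b': 2}
```

Keeping the full key set on both sides presumably lets combine_partitions restore the original layout without extra bookkeeping.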
- everwillow.statelib.state.split(state)
Split a merged State back into the original States.
For overlapping keys, all returned segments receive the merged value.
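The documented overlap rule can be illustrated on plain dicts. How a merged State records which keys came from which input is not described here. The sketch therefore assumes that this bookkeeping is available as `segments`: a list of key collections, one per original state, in merge order. Both `split_sketch` and `segments` are hypothetical names.

```python
def split_sketch(merged, segments):
    """Rebuild one dict per original input from the merged mapping.

    `segments` is a hypothetical record of each input's keys, in merge order.
    """
    # Every segment reads from the merged mapping, so shared keys get the merged value.
    return [{key: merged[key] for key in keys} for keys in segments]


a = {"x": 1, "y": 2}
b = {"y": 20, "z": 3}
merged = {**a, **b}  # last value wins for "y"
print(split_sketch(merged, [a.keys(), b.keys()]))
# [{'x': 1, 'y': 20}, {'y': 20, 'z': 3}]
```

The original value `2` for `"y"` in `a` is not recovered. Both segments see `20`, as the overlap rule states.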
- everwillow.statelib.state.update(state, *, updates)
Return a new state with specific entries replaced.
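- Parameters:
state (State[V]) – State whose entries are replaced.
updates – Mapping from existing keys to their new values.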
- Returns:
New State with the replacements applied.
- Raises:
KeyError – If updates includes a key that is not present in state.
- Return type:
State[V]
Examples
>>> base = State.from_pytree({"a": 1, "b": 2})
>>> update(base, updates={"b": 99}).to_dict()
{'a': 1, 'b': 99}
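Below is a minimal sketch of the replace-only semantics on plain mappings. It copies the input, rejects unknown keys with KeyError, and returns a read-only view. `update_sketch` is a hypothetical name, and the sketch ignores the treedefmeta that a real State carries.

```python
from types import MappingProxyType


def update_sketch(mapping, updates):
    """Hypothetical mapping-only version of update: replace existing keys, never add new ones."""
    unknown = [key for key in updates if key not in mapping]
    if unknown:
        raise KeyError(f"keys not present in state: {unknown}")
    replaced = dict(mapping)  # copy, so the input is left untouched
    replaced.update(updates)
    return MappingProxyType(replaced)


base = {"a": 1, "b": 2}
print(dict(update_sketch(base, {"b": 99})))  # {'a': 1, 'b': 99}
print(base)                                  # {'a': 1, 'b': 2}
```

Rejecting unknown keys keeps the key set fixed. Presumably this is what allows the stored pytree definition to stay valid after an update.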