Source code for everwillow._src.statelib.transform

"""State transformation helpers for :mod:`everwillow._src.statelib.state`."""

from __future__ import annotations

# Public API of this module: the names exported by ``import *``.
__all__ = ["Transform", "apply_transformations"]

import dataclasses
import typing as tp

from everwillow._src.statelib.state import K, State, V


def _identity(key: K, value: V) -> V:
    """Pass ``value`` through untouched; the default ``Transform.value_fn``.

    Args:
        key: The entry's original key. It is ignored and accepted only so
            that the signature matches ``Callable[[K, V], V]``.
        value: The entry's current value.

    Returns:
        The very same ``value`` object that was passed in.
    """
    _ = key  # an identity mapping has no use for the key
    return value


@dataclasses.dataclass(frozen=True)
class Transform(tp.Generic[V]):
    """Rewrite rule for a single entry of a ``State``.

    A ``Transform`` names the key an entry is stored under after rewriting.
    Optionally, it also gives a callable that derives the new value from
    ``(old_key, old_value)``. When ``value_fn`` is omitted, the default
    returns the value unchanged.

    Examples:
        >>> transform = Transform(new_key="scale", value_fn=lambda _k, v: 2 * v)
        >>> transform.new_key
        'scale'
    """

    #: Replacement canonical key used in the transformed state.
    new_key: K
    #: Callable applied to derive the transformed value.
    value_fn: tp.Callable[[K, V], V] = _identity
def apply_transformations(
    state: State[V],
    transformations: tp.Mapping[K, Transform[V]],
) -> State[V]:
    """Rewrite selected entries in a ``State``.

    Each entry whose key appears in ``transformations`` is stored under
    ``Transform.new_key``, with value ``Transform.value_fn(key, value)``.
    Every other entry is copied through unchanged. Entries are emitted in
    ``state.items()`` order, and that resulting key order is recorded in
    the ``keys`` field of the new ``treedefmeta``.

    Args:
        state: Original state whose entries will be rewritten. This function
            only reads it.
        transformations: Mapping from existing keys of ``state`` to the
            ``Transform`` applied to each of them.

    Returns:
        A new ``State`` containing the transformed key/value pairs. If
        ``transformations`` is empty, ``state`` itself is returned as-is.

    Raises:
        TypeError: If ``state`` is not a ``State`` instance.
        KeyError: If ``transformations`` references keys not present in
            ``state``. The offending keys are listed in the iteration order
            of ``transformations``.
        ValueError: If the result would contain a duplicate key. This happens
            when several transformations target the same destination key, or
            when a destination key collides with an untransformed entry.

    Examples:
        >>> base = State.from_pytree({"a": 1, "b": 2})
        >>> transform = {"a": Transform(new_key="alpha")}
        >>> apply_transformations(base, transform).to_dict()
        {'alpha': 1, 'b': 2}
    """
    if not isinstance(state, State):
        msg = "'state' must be a State instance"  # type: ignore[unreachable]
        raise TypeError(msg)
    if not transformations:
        # Nothing to rewrite; hand back the input rather than rebuilding it.
        return state

    present = set(state.keys())
    # A list (not a set difference) keeps the error message stable across
    # runs, since set iteration order depends on hash randomization.
    if missing_keys := [k for k in transformations if k not in present]:
        msg = f"transformations reference keys not present in state: {missing_keys}"
        raise KeyError(msg)

    new_data: dict[K, V] = {}
    for key, value in state.items():
        if key in transformations:
            transform = transformations[key]
            new_key = transform.new_key
            # Check for a clash before running value_fn, so the function is
            # never called for an entry that is about to be rejected.
            if new_key in new_data:
                msg = f"multiple transformations target the same key: {new_key}"
                raise ValueError(msg)
            new_data[new_key] = transform.value_fn(key, value)
        else:
            if key in new_data:
                msg = f"transformation produced duplicate key: {key}"
                raise ValueError(msg)
            new_data[key] = value

    # Every insertion above is collision-checked, so nothing is overwritten
    # and the dict's insertion order is exactly the new key order.
    # ``dataclasses.replace`` builds a copy; ``state.treedefmeta`` is untouched.
    treedefmeta = dataclasses.replace(state.treedefmeta, keys=tuple(new_data))
    return State(mapping=new_data, treedefmeta=treedefmeta)