evofabric.core.graph._state 源代码

# -*- coding: utf-8 -*-
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.

import copy
from functools import lru_cache
from typing import (
    Any, Callable, ClassVar, Dict, get_args, get_origin, List, Optional, Tuple, Union
)

from pydantic import create_model, Field, SkipValidation, TypeAdapter
from typing_extensions import Annotated

from ._state_update import StateUpdater
from ._utils import _make_class_name
from ..factory import (
    BaseComponent, fill_defaults, is_basemodel, is_typeddict, safe_convert_to_schema, safe_get_attr,
    safe_set_attr, StateSchemaSerializable
)
from ..typing import MISSING, State, StateDelta, StateMessage, StateSchema


@lru_cache
def get_update_function(state_schema):
    """Collect the update functions declared through ``Annotated`` on a schema.

    The classes in ``state_schema.__mro__`` are visited in reverse order
    (most generic base first), so an entry recorded for a later class
    overwrites an earlier entry stored under the same path.

    Annotations are skipped when their origin is ``ClassVar`` or when the
    attribute name ends with an underscore. For the remaining annotations:

    * ``Annotated[T, "strategy", ...]`` records ``[T, StateUpdater.get("strategy")]``
      under the current path, but only when the first metadata item is a
      string. The wrapped type ``T`` is not inspected any further.
    * ``Dict[K, V]`` / ``dict[K, V]`` continues into ``V`` with ``".*"``
      appended to the path.
    * a TypedDict (as decided by ``is_typeddict``) continues into each of its
      annotations with ``".<key>"`` appended to the path.
    * anything else contributes nothing.

    The result is memoised per ``state_schema`` by ``lru_cache``, so every
    caller receives the same dict object.

    Args:
        state_schema: The schema class to inspect; it must be hashable.

    Returns:
        A dict mapping a dotted path to a two-item list
        ``[declared type, update function]``.
    """

    def _collect(annotation, resolve, path: str, found: dict[str, list]):
        # Walk one annotation and store every updater it declares in `found`.
        origin = get_origin(annotation)
        if origin is Annotated:
            declared, strategy, *_ = get_args(annotation)
            # Only a string in the first metadata slot names a strategy.
            if isinstance(strategy, str):
                found[path] = [declared, resolve(strategy)]
        elif origin is dict or origin is Dict:
            type_args = get_args(annotation)
            if len(type_args) == 2:  # Dict[K, V]: descend into the value type
                _collect(type_args[1], resolve, path + ".*", found)
        elif is_typeddict(annotation):
            for key, sub_annotation in annotation.__annotations__.items():
                _collect(sub_annotation, resolve, f"{path}.{key}", found)

    collected: dict[str, list] = {}
    for klass in reversed(state_schema.__mro__):
        if hasattr(klass, "__annotations__"):
            for field_name, annotation in klass.__annotations__.items():
                if get_origin(annotation) is not ClassVar and not field_name.endswith("_"):
                    _collect(annotation, StateUpdater.get, field_name, collected)
    return collected


def generate_state_schema(
        variables: Optional[List[Tuple[str, Any, str]]] = None
):
    """
    Build a state schema model from a list of variable declarations.

    The model always contains a ``messages`` field annotated as
    ``Annotated[list[StateMessage], "append_messages"]``. Each declared
    variable is added as a required field annotated with its type and its
    update strategy name. The model is created with ``create_model`` under a
    name obtained from ``_make_class_name("StateSchema")``.

    Notes:
        * A type given as a string is turned into a type with ``eval``; any
          other value is used as the type directly.
        * ``messages`` is reserved; declaring a variable with that name is
          reported as a repeated name.

    Args:
        variables: A list of ``(name, type, update_strategy)`` tuples.
            ``None`` is treated as an empty list.

    Examples:
        generate_state_schema([("msg_id", "str", "overwrite"), ("user_id", bool, "overwrite")])

    Returns:
        The model class returned by ``create_model``.

    Raises:
        ValueError: If a name is ``messages`` or was already declared.
        ValueError: If the update strategy is not registered in StateUpdater.
        ValueError: If the type cannot be evaluated or wrapped in ``Annotated``.
    """
    variables = variables or []
    # The reserved field comes first so that the duplicate check below also
    # rejects a user variable called "messages".
    fields = {
        "messages": (Annotated[list[StateMessage], "append_messages"], ...)
    }

    for name, type_str, update_strategy in variables:
        if name in fields:
            raise ValueError(f"Got an repeat name: {name}")
        if not StateUpdater.registered(update_strategy):
            raise ValueError(f"Got an unregistered update strategy: {update_strategy}")

        try:
            if isinstance(type_str, str):
                # SECURITY NOTE(review): eval() executes arbitrary Python and
                # resolves names against this function's locals and the module
                # globals. Only pass type strings that come from trusted code.
                type_str = eval(type_str)
            annotated_type = Annotated[type_str, update_strategy]
        except Exception as exc:
            raise ValueError(f"Cannot convert input to a valid annotated, got {type_str}") from exc

        # `...` marks the field as required (no default value).
        fields[name] = (annotated_type, ...)

    return create_model(_make_class_name("StateSchema"), **fields)
class StateCkpt(BaseComponent, StateSchemaSerializable):
    # A checkpoint in a chain of state updates: it stores a delta, an optional
    # parent checkpoint and the schema used to merge them. The full state is
    # rebuilt on demand by `materialize`.
    #
    # NOTE(review): documented with comments rather than a class docstring on
    # purpose. The Field/SkipValidation usage suggests BaseComponent is a
    # pydantic model, where a class docstring would change `__doc__` and may
    # surface in generated schemas -- confirm before converting to a docstring.

    # Update applied on top of the parent's materialized state.
    delta: Optional[SkipValidation[StateDelta]] = None
    # Previous checkpoint; None means `materialize` starts from schema defaults.
    parent: Optional['StateCkpt'] = None
    # Schema class that drives `merge_state` and `fill_defaults`.
    state_schema: Optional[type[StateSchema]] = None
    # Memoised result of `materialize`; deep copies go in and out of it.
    materialized_state_cache: Optional[SkipValidation[State]] = Field(
        default=None, init=False, repr=False
    )

    @staticmethod
    def merge_state(state, delta, state_schema) -> Union[Dict, StateSchema]:
        """Return a copy of ``state`` with ``delta`` merged in.

        ``state`` is deep-copied first, so the argument is not modified. For
        every ``path -> (type, update function)`` entry produced by
        ``get_update_function(state_schema)``, the current value and the delta
        value are read with ``safe_get_attr`` (``MISSING`` when absent), passed
        to the update function, validated against the declared type and
        written back with ``safe_set_attr``.

        Args:
            state: The base state.
            delta: The update to merge into ``state``.
            state_schema: The schema class describing the state.

        Returns:
            The merged state. When ``state_schema`` is a base model (per
            ``is_basemodel``) and the merged state is not an instance of it,
            the state is converted with ``safe_convert_to_schema``; when it is
            a TypedDict, the state is rebuilt as ``state_schema(**state)``.
        """
        state = copy.deepcopy(state)
        updater = get_update_function(state_schema)
        for attr, (typ, update_func) in updater.items():
            current_value = safe_get_attr(state, attr, MISSING)
            delta_value = safe_get_attr(delta, attr, MISSING)
            merged_value = update_func(current_value, delta_value)
            safe_set_attr(
                state,
                attr,
                TypeAdapter(typ).validate_python(merged_value),
            )

        if is_basemodel(state_schema) and not isinstance(state, state_schema):
            state = safe_convert_to_schema(state, state_schema)
        elif is_typeddict(state_schema):
            state = state_schema(**state)
        return state

    def materialize(self) -> Union[State, StateSchema]:
        """Rebuild the full state represented by this checkpoint.

        The parent chain is materialized recursively; a checkpoint without a
        parent starts from ``fill_defaults(self.state_schema)``. ``self.delta``
        is then merged in with ``merge_state``. The result is cached, and both
        the cached value and the returned value are deep copies, so callers
        can modify what they receive without affecting the cache.

        Returns:
            The materialized state.
        """
        # Compare against None explicitly: a cached state that happens to be
        # falsy (for example an empty mapping) is still a valid cache hit.
        if self.materialized_state_cache is not None:
            return copy.deepcopy(self.materialized_state_cache)

        if self.parent is not None:
            state = self.parent.materialize()
        else:
            state = fill_defaults(self.state_schema)

        updated = self.merge_state(state, self.delta, self.state_schema)
        self.materialized_state_cache = copy.deepcopy(updated)
        return updated

    @classmethod
    def merge(
            cls,
            checkpoints: List['StateCkpt'],
            strategy: Optional[Callable[[List[State]], State]] = None
    ):
        """Combine several checkpoints into one parentless checkpoint.

        Every checkpoint is materialized. With a ``strategy`` the list of
        materialized states is handed to it and its return value is used.
        Without one, the states are folded left to right with ``merge_state``,
        each later state acting as the delta for the accumulated result.

        The schema of the first checkpoint is used both for the fold and for
        the returned checkpoint.

        Args:
            checkpoints: The checkpoints to merge; must not be empty.
            strategy: Optional callable that receives the materialized states
                and returns the merged state.

        Returns:
            A new checkpoint whose ``delta`` is the merged state and whose
            ``parent`` is ``None``.

        Raises:
            ValueError: If ``checkpoints`` is empty.
        """
        if not checkpoints:
            raise ValueError("Cannot merge an empty list of checkpoints.")

        state_schema = checkpoints[0].state_schema
        parent_states = [ckpt.materialize() for ckpt in checkpoints]

        if strategy is not None:
            merged_state = strategy(parent_states)
        else:
            # No custom strategy: fall back to the schema's update functions.
            merged_state = parent_states[0]
            for state in parent_states[1:]:
                merged_state = cls.merge_state(merged_state, state, state_schema)

        return cls(
            delta=merged_state,
            parent=None,
            state_schema=state_schema,
        )

    @classmethod
    def filter(
            cls,
            checkpoint: 'StateCkpt',
            strategy: Callable[[State], State]
    ):
        """Create a parentless checkpoint from a transformed state.

        Args:
            checkpoint: The checkpoint whose materialized state is transformed.
            strategy: Callable that receives the materialized state and
                returns the state to store.

        Returns:
            A new checkpoint whose ``delta`` is the value returned by
            ``strategy``, with ``parent`` set to ``None`` and the same
            ``state_schema`` as ``checkpoint``.
        """
        state = strategy(checkpoint.materialize())
        return cls(
            delta=state,
            parent=None,
            state_schema=checkpoint.state_schema,
        )