evofabric.core.graph._graph 源代码

# -*- coding: utf-8 -*-
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import json
from typing import Callable, Dict, List, Optional, Set, Union

from pydantic import Field, PrivateAttr

from ..factory import BaseComponent, dump_schema_annotated_info, StateSchemaSerializable
from ._edge import cast_edge, ConditionEdgeSpec, EdgeSpec
from ._engine import GraphEngine
from ._engine_debugger import GraphEngineDebugger
from ._node import EndNode, GraphNodeSpec, NodeBase, StartNode
from ._streaming import get_stream_writer
from ..typing import DEFAULT_EDGE_GROUP, NodeActionMode, SpecialNode, State, StateSchema, GraphMode


class GraphBuilder(BaseComponent, StateSchemaSerializable):
    """Incrementally assemble nodes and edges, then build a graph engine from them.

    Typical usage: ``add_node`` for every node, ``add_edge`` /
    ``add_condition_edge`` to connect them, ``set_entry_point`` to choose where
    execution starts, and finally ``build`` to obtain an engine. The builder
    can be written to and restored from a JSON-compatible dict with
    ``dumps`` / ``loads`` (or a file with ``dump`` / ``load``).
    """

    state_schema: type[StateSchema] = Field(description="State schema of this graph")
    """State schema of this graph"""

    _nodes_: Dict[str, GraphNodeSpec] = PrivateAttr(default_factory=dict)
    """Record all nodes in this graph"""

    _edges_: Dict[str, list] = PrivateAttr(default_factory=dict)
    """Record edges that fan-out from one node"""

    _multi_input_merge_: Dict[str, Dict[str, Callable[[list[State]], State]]] = PrivateAttr(default_factory=dict)
    """Record state merge strategy when a node has multiple predecessors"""

    _entry_set_: bool = PrivateAttr(default=False)
    """Mark if this graph has a start node."""

    def _add_start_node(self):
        """Register the built-in start node, replacing any previous registration."""
        self._nodes_[SpecialNode.START_NODE.value] = GraphNodeSpec(
            node_name=SpecialNode.START_NODE.value,
            node=StartNode(),
            action_mode=NodeActionMode("any"),
            stream_writer=get_stream_writer(),
        )

    def _add_end_node(self):
        """Register the built-in end node, replacing any previous registration."""
        self._nodes_[SpecialNode.END_NODE.value] = GraphNodeSpec(
            node_name=SpecialNode.END_NODE.value,
            node=EndNode(),
            action_mode=NodeActionMode("any"),
            stream_writer=get_stream_writer(),
        )

    def _has_end_node(self) -> bool:
        """Return True when the end node has been registered in this builder."""
        return SpecialNode.END_NODE.value in self._nodes_

    def _has_node(self, node_name) -> bool:
        """Return True when ``node_name`` is a special node name or a registered node."""
        return SpecialNode.is_special_node(node_name) or node_name in self._nodes_

    def _register_end_if_targeted(self, targets):
        """Register the end node when one of ``targets`` explicitly points at it.

        ``_has_node`` accepts the end node name as an edge target even before
        the end node is registered. Without this step a graph whose edges were
        wired to the end node by hand could never pass the end-node check in
        ``build(auto_conn_end=False)``.
        """
        if SpecialNode.END_NODE.value in targets and not self._has_end_node():
            self._add_end_node()

    def add_node(
            self,
            name: str,
            node: Union[Callable, NodeBase],
            action_mode: Union[NodeActionMode, str] = "any",
            multi_input_merge_strategy: Optional[dict[str, Callable[[List[State]], State]]] = None,
    ):
        """
        Add a node to the graph.

        Args:
            name: node name
            node: callable node, can be NodeBase or a simple python function
            action_mode:
                'all' means this node should only be executed when all predecessors
                are executed in one channel.
                'any' means this node will be executed instantly when one predecessor
                is executed.
            multi_input_merge_strategy:
                dict[channel_name(str), merge_strategy for each channel], useful for
                merging state when a node has multiple predecessors.

        Raises:
            ValueError: if ``name`` is a reserved special node name, if a node with
                that name already exists, or if ``action_mode`` is not a valid
                ``NodeActionMode`` value.
        """
        if SpecialNode.is_special_node(name):
            raise ValueError(f"Node '{name}' is preserved by GraphBuilder.")

        if self._has_node(name):
            raise ValueError(f"Node '{name}' already exists.")

        self._nodes_[name] = GraphNodeSpec(
            node_name=name,
            node=node,
            action_mode=NodeActionMode(action_mode),
            stream_writer=get_stream_writer(),
            multi_input_merge_strategy=multi_input_merge_strategy,
        )

    def add_edge(
            self,
            source: str,
            target: str,
            group: str = DEFAULT_EDGE_GROUP,
            state_filter: Optional[Callable[[State], State]] = None,
    ):
        """
        Connect two nodes.

        Args:
            source: source node name
            target: target node name
            group: edge group name
            state_filter: filter state function if needed

        Raises:
            ValueError: if ``source`` or ``target`` is neither a registered node
                nor a special node name.
        """
        if not self._has_node(source):
            raise ValueError(f"Node '{source}' not exists.")
        if not self._has_node(target):
            raise ValueError(f"Node '{target}' not exists.")

        self._register_end_if_targeted((target,))
        self._edges_.setdefault(source, []).append(
            EdgeSpec(
                source=source,
                target=target,
                group=group,
                state_filter=state_filter,
            )
        )

    def add_condition_edge(
            self,
            source: str,
            router: Callable,
            possible_targets: Union[List[str], Set[str]],
            group: str = DEFAULT_EDGE_GROUP,
    ):
        """
        Route to the next node(s) by condition.

        Four router return shapes are supported:

        * next target node name: return a single str
        * multiple next target node names: return a list of str
        * next target node name with specific state_filter:
          return a tuple[str, Callable]
        * multiple next target node names with specific state_filter:
          return a list of tuple[str, Callable]

        Args:
            source: name of the node the conditional edge leaves from
            router: callable that decides the next target(s)
            possible_targets: every node name the router may select
            group: edge group name

        Raises:
            ValueError: if ``source`` or any entry of ``possible_targets`` is
                neither a registered node nor a special node name.
        """
        if not self._has_node(source):
            raise ValueError(f"Node '{source}' not exists.")

        # De-duplicate while keeping first-seen order so the stored edge (and
        # therefore the output of ``dumps``) does not vary between runs.
        unique_targets = list(dict.fromkeys(possible_targets))
        for target in unique_targets:
            if not self._has_node(target):
                raise ValueError(f"Node '{target}' not exists.")

        self._register_end_if_targeted(unique_targets)
        self._edges_.setdefault(source, []).append(
            ConditionEdgeSpec(
                source=source,
                router=router,
                group=group,
                possible_targets=unique_targets,
            )
        )

    def set_entry_point(self, entry_name: str):
        """
        Mark ``entry_name`` as the first node executed after the start node.

        Registers the start node and replaces any edges previously leaving it
        with a single edge to ``entry_name``.

        NOTE: ``entry_name`` is not checked against the registered nodes here.

        Args:
            entry_name: name of the node to run first
        """
        self._add_start_node()
        self._edges_[SpecialNode.START_NODE.value] = [
            EdgeSpec(source=SpecialNode.START_NODE.value, target=entry_name)
        ]
        self._entry_set_ = True

    def _auto_connect_end(self):
        """Connect every non-special node that has no outgoing edge to the end node."""
        self._add_end_node()
        dangling = [
            node_name
            for node_name in self._nodes_
            if not SpecialNode.is_special_node(node_name) and not self._edges_.get(node_name)
        ]
        for node_name in dangling:
            self.add_edge(node_name, SpecialNode.END_NODE.value)

    def build(
            self,
            auto_conn_end: bool = True,
            max_turn: Optional[int] = None,
            timeout: Optional[int] = None,
            graph_mode: Union[GraphMode, str] = "run",
            db_file_path: str = "./.storage.db",
            db_name: str = "evofabric",
    ):
        """
        Create an engine from the nodes and edges collected so far.

        Args:
            auto_conn_end: when True, nodes without outgoing edges are connected
                to the end node before building
            max_turn: forwarded to the engine
            timeout: forwarded to the engine
            graph_mode: ``GraphMode`` member or its value; ``GraphMode.RUN``
                yields a ``GraphEngine``, any other mode a ``GraphEngineDebugger``
            db_file_path: forwarded to ``GraphEngineDebugger`` only
            db_name: accepted for interface compatibility; not used by this method

        Returns:
            A ``GraphEngine`` or ``GraphEngineDebugger`` instance.

        Raises:
            AssertionError: if no entry point was set or the graph has no end node.
            ValueError: if ``graph_mode`` is not a valid ``GraphMode`` value.
        """
        # Raised explicitly (instead of via ``assert``) so the checks still run
        # under ``python -O``; the exception type is unchanged for callers.
        if not self._entry_set_:
            raise AssertionError(
                "Graph must have an entry point, using set_entry_point('node_name') to assign one."
            )

        if auto_conn_end:
            self._auto_connect_end()

        if not self._has_end_node():
            raise AssertionError("Graph must have at least one end node.")

        graph_mode = GraphMode(graph_mode)
        if graph_mode == GraphMode.RUN:
            return GraphEngine(
                nodes=self._nodes_,
                edges=self._edges_,
                state_schema=self.state_schema,
                max_turn=max_turn,
                timeout=timeout,
            )
        return GraphEngineDebugger(
            nodes=self._nodes_,
            edges=self._edges_,
            state_schema=self.state_schema,
            max_turn=max_turn,
            timeout=timeout,
            db_file_path=db_file_path,
        )

    def dumps(self, graph_name: str = "graph", version: str = "1.0") -> dict:
        """
        Dump the existing components of this graph builder into a config dict.

        Args:
            graph_name: stored under the ``"graph_name"`` key
            version: stored under the ``"version"`` key

        Returns:
            A JSON-serializable dict with the keys ``graph_name``, ``version``,
            ``state_schema``, ``nodes``, ``edges`` and ``entry_set``.

        Raises:
            ValueError: if a node or an edge list cannot be dumped to JSON.
        """
        nodes_json = {}
        for name, spec in self._nodes_.items():
            dumped = spec  # reported in the error if model_dump() itself fails
            try:
                dumped = spec.model_dump()
                # pre-check that the dumped node is JSON-serializable
                json.dumps(dumped)
            except Exception as e:
                raise ValueError(f"Cannot serialize node {name}, value: {dumped}") from e
            nodes_json[name] = dumped

        edges_json = {}
        for name, specs in self._edges_.items():
            dumped = specs  # reported in the error if model_dump() itself fails
            try:
                dumped = [spec.model_dump() for spec in specs]
                # pre-check that the dumped edges are JSON-serializable
                json.dumps(dumped)
            except Exception as e:
                raise ValueError(f"Cannot serialize edge of node {name}, value: {dumped}") from e
            edges_json[name] = dumped

        return {
            "graph_name": graph_name,
            "version": version,
            "state_schema": dump_schema_annotated_info(self.state_schema),
            "nodes": nodes_json,
            "edges": edges_json,
            "entry_set": self._entry_set_,
        }

    def dump(self, save_path: str, graph_name: str = "graph", version: str = "1.0"):
        """
        Dump the existing components of this graph builder into a JSON file.

        Args:
            save_path: path of the file to write (UTF-8, overwritten if present)
            graph_name: forwarded to ``dumps``
            version: forwarded to ``dumps``
        """
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(self.dumps(graph_name, version), f, indent=4, ensure_ascii=False)

    @classmethod
    def load(cls, file_path: str) -> 'GraphBuilder':
        """
        Load a graph builder from a JSON file.

        Args:
            file_path: path of a UTF-8 JSON file holding a config dict

        Returns:
            The restored ``GraphBuilder``.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls.loads(data)

    @classmethod
    def loads(cls, data: dict) -> 'GraphBuilder':
        """
        Load a graph builder from a config dict.

        Args:
            data: config dict; ``"state_schema"`` is required, while ``"edges"``,
                ``"nodes"`` and ``"entry_set"`` fall back to empty / False

        Returns:
            The restored ``GraphBuilder``.

        Raises:
            KeyError: if ``data`` has no ``"state_schema"`` entry.
        """
        graph_builder = cls(state_schema=data["state_schema"])
        graph_builder._edges_ = {
            source: [cast_edge(edge) for edge in edges]
            for source, edges in data.get("edges", {}).items()
        }
        graph_builder._nodes_ = {
            name: GraphNodeSpec(**spec)
            for name, spec in data.get("nodes", {}).items()
        }
        graph_builder._entry_set_ = data.get("entry_set", False)
        return graph_builder