# Source code for evofabric.core.graph._engine_debugger

# -*- coding: utf-8 -*-
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.

import asyncio
import sqlite3
import uuid
from collections import deque
from copy import deepcopy
from typing import Any, Dict, List, Optional, Set

from pydantic import BaseModel, Field, PrivateAttr

from ._edge import ConditionEdgeSpec, EdgeSpec, EdgeSpecBase
from ._engine import GraphEngine, RunTimeTask
from ._state import StateCkpt
from ..typing import DEFAULT_EDGE_GROUP, NodeActionMode, SpecialNode
from ...logger import get_logger

logger = get_logger()

BranchFinished = "branch_finished"


class TreeNode(BaseModel):
    """A node of the runtime task tree.

    ``children`` is a regular (serialised) field; the parent links live in a
    private attribute and are therefore not part of the dumped model.
    Identity (``==`` and ``hash``) is defined solely by ``uuid``.
    """

    node_name: str
    """node name"""

    task: Optional[RunTimeTask] = None
    """task of this node"""

    children: Dict[str, 'TreeNode'] = Field(default_factory=dict)
    """children of this node, keyed by the child's uuid"""

    uuid: str = Field(default_factory=lambda: str(uuid.uuid4()))
    """unique id for tree node"""

    _parent: Dict[str, 'TreeNode'] = PrivateAttr(default_factory=dict)
    """parents of this node, keyed by the parent's uuid"""

    def __repr__(self):
        # Render the parent keys as a plain list, the same way children are
        # rendered, instead of leaking a ``dict_keys(...)`` view into the text.
        parent_name = list(self._parent.keys()) if self._parent else "None"
        return (f"TreeNode(uuid='{self.uuid[:8]}...', name='{self.node_name}', "
                f"parent='{parent_name}', children={list(self.children.keys())})")

    def __hash__(self) -> int:
        return hash(self.uuid)

    def __eq__(self, other):
        # Comparing with a foreign object must not raise AttributeError;
        # returning NotImplemented lets Python fall back to its default.
        if not isinstance(other, TreeNode):
            return NotImplemented
        return self.uuid == other.uuid

    @property
    def parent(self):
        """Mapping ``parent uuid -> parent node`` (mutable, not serialised)."""
        return self._parent

    def add_child(self, node: 'TreeNode'):
        """Link ``node`` below this node in both directions.

        Both mappings are keyed by uuid, matching how ``RuntimeTaskTree``
        links and looks up nodes.  The parent mapping is updated in place so
        that a node can keep several parents.
        """
        self.children[node.uuid] = node
        node._parent[self.uuid] = self


class RuntimeTaskTree(BaseModel):
    """
    A doubly-linked tree/graph structure for managing task execution flow.
    Supports backward traversal and branch pruning.

    Nodes are linked parent -> children (serialised) and child -> parents
    (private, rebuilt after loading).  The whole tree can be persisted as a
    single JSON document in a SQLite table.
    """

    db_path: str = Field(default=".db_storage", exclude=True)
    """Persistence database save path"""

    _root: Optional[TreeNode] = PrivateAttr(default=None)
    """Root node of this tree"""

    _leaf_nodes: Dict[str, TreeNode] = PrivateAttr(default_factory=dict)
    """Stores all nodes that currently have no children"""

    _uuid_map: Dict[str, TreeNode] = PrivateAttr(default_factory=dict)
    """A quick lookup map from UUID to Node object"""

    _conn: Optional[sqlite3.Connection] = PrivateAttr(default=None)
    """DB Connection object"""

    def model_post_init(self, __context: Any) -> None:
        """Open the SQLite connection and make sure the storage table exists."""
        self._conn = sqlite3.connect(self.db_path)
        self._create_tables()

    def clear_tree(self):
        """Drop the in-memory tree; the database content is left untouched."""
        self._root = None
        self._leaf_nodes = {}
        self._uuid_map = {}

    def _create_tables(self):
        """Create the ``trees`` table on first use."""
        cur = self._conn.cursor()
        cur.execute("""
        CREATE TABLE IF NOT EXISTS trees (
            root_uuid TEXT PRIMARY KEY,
            data TEXT NOT NULL
        )
        """)
        self._conn.commit()

    def save_tree(self):
        """Save entire tree to database.

        The tree is stored under the fixed key ``"HEAD"``.  When the tree is
        empty there is nothing to serialise, so the stored snapshot is removed
        instead of failing on a missing root.
        """
        cur = self._conn.cursor()
        if self._root is None:
            logger.warning("Tree is empty, removing the persisted tree snapshot.")
            cur.execute("DELETE FROM trees WHERE root_uuid = ?", ("HEAD",))
        else:
            cur.execute("""
                INSERT OR REPLACE INTO trees (root_uuid, data)
                VALUES (?, ?)
            """, ("HEAD", self._root.model_dump_json()))
        self._conn.commit()

    def load_tree(self) -> Optional[TreeNode]:
        """Load entire tree from database.

        Returns:
            The loaded root node, or ``None`` when nothing has been saved (in
            which case the in-memory tree is left unchanged).
        """
        cur = self._conn.cursor()
        cur.execute("SELECT data FROM trees WHERE root_uuid = ?", ("HEAD",))
        row = cur.fetchone()
        if not row:
            return None
        self._root = TreeNode.model_validate_json(row[0])
        self._rebuild_tree_params()
        return self._root

    def _rebuild_tree_params(self):
        """Rebuild ``_uuid_map``, ``_leaf_nodes`` and the parent links from ``_root``.

        Parent links are private and thus absent from the serialised form, and
        a node that had several parents is serialised once per parent.  The
        walk below restores the parent links and re-shares such duplicates so
        that each uuid maps to exactly one object again.
        """
        # Start from a clean slate so that entries of a previously held tree
        # cannot survive a reload.
        self._uuid_map = {}
        self._leaf_nodes = {}
        if self._root is None:
            return

        def dfs(node: TreeNode):
            self._uuid_map[node.uuid] = node
            if not node.children:
                self._leaf_nodes[node.uuid] = node
                return
            # Iterate over a snapshot because duplicates are replaced in place.
            for key, child in list(node.children.items()):
                known = self._uuid_map.get(child.uuid)
                if known is not None:
                    # Already visited through another parent: reuse that
                    # object and do not descend a second time.
                    node.children[key] = known
                    known.parent[node.uuid] = node
                    continue
                child.parent[node.uuid] = node
                dfs(child)

        dfs(self._root)

    def clear(self):
        """Clear database"""
        cur = self._conn.cursor()
        cur.execute("DELETE FROM trees")
        self._conn.commit()

    def close(self):
        """Close database connection"""
        self._conn.close()

    def get_node_by_uuid(self, node_uuid: str) -> Optional[TreeNode]:
        """Helper to find a node by its UUID."""
        return self._uuid_map.get(node_uuid)

    def get_leaf_node_uuid_by_node_name(self, node_name: str) -> str:
        """Helper to find a leaf node uuid by its node name.

        Returns the first matching leaf in insertion order.

        Raises:
            ValueError: if no leaf carries ``node_name``.
        """
        for node_uuid, node in self._leaf_nodes.items():
            if node.node_name == node_name:
                return node_uuid
        raise ValueError(f"No leaf node uuid found for node name {node_name}")

    def _traverse_depth_first(self, node: TreeNode) -> List[TreeNode]:
        """Return ``node`` and all of its descendants in pre-order."""
        if not node:
            return []

        result = [node]
        for child in node.children.values():
            result.extend(self._traverse_depth_first(child))
        return result

    def add_node(self, node_name: str, task: Any, parent_uuid: Optional[str] = None) -> TreeNode:
        """
        Adds a new leaf node to the tree.
        If the tree is empty the new node becomes the root and ``parent_uuid``
        is ignored; otherwise ``parent_uuid`` must name an existing node.

        Raises:
            ValueError: if the tree is not empty and the parent is unknown.
        """
        new_node = TreeNode(node_name=node_name, task=task)

        if self._root is None:
            # First node becomes the root
            self._root = new_node
            self._uuid_map[new_node.uuid] = new_node
            self._leaf_nodes[new_node.uuid] = new_node
            return new_node

        # Validate the parent before registering anything, so a failed call
        # does not leave an unreachable node behind in the lookup map.
        parent_node = self.get_node_by_uuid(parent_uuid)
        if parent_node is None:
            raise ValueError(f"Parent node with UUID '{parent_uuid}' not found.")

        self._uuid_map[new_node.uuid] = new_node

        # Establish bidirectional link
        new_node.parent[parent_node.uuid] = parent_node
        parent_node.children[new_node.uuid] = new_node

        # Update leaf nodes: parent is no longer a leaf, the new node is
        self._leaf_nodes.pop(parent_node.uuid, None)
        self._leaf_nodes[new_node.uuid] = new_node

        return new_node

    def merge_nodes(self, from_nodes: List[TreeNode], to_node: TreeNode):
        """
        Merge every node of ``from_nodes`` into ``to_node``.

        This targets nodes that wait for all predecessors: when the execution
        is a -> b -> c and a -> d -> c, both paths must end in the same runtime
        node c.  The parents of each merged node are re-linked to ``to_node``
        and the merged node stops being a leaf.  ``to_node`` keeps a single
        trace route; this does not affect walking backwards through parents.
        Merged nodes stay in the uuid lookup map.
        """
        for from_node in from_nodes:
            if from_node.uuid == to_node.uuid:
                continue
            # Re-parent: every parent of the merged node now points at to_node
            for from_node_parent in list(from_node.parent.values()):
                from_node_parent.children[to_node.uuid] = to_node
                from_node_parent.children.pop(from_node.uuid)
                to_node.parent[from_node_parent.uuid] = from_node_parent
            # The merged node is no longer an active leaf
            self._leaf_nodes.pop(from_node.uuid, None)

    def backtrack_and_prune(self, node_uuid: str) -> Optional[TreeNode]:
        """
        Backtracks to the specified node: all of its descendants are removed
        from the lookup maps and the node itself becomes a leaf again.

        Returns:
            The specified node (the new leaf), or ``None`` when no node with
            ``node_uuid`` exists.
        """
        target_node = self.get_node_by_uuid(node_uuid)
        if target_node is None:
            logger.warning(f"Node with UUID '{node_uuid}' not found for backtracking.")
            # Nothing to prune; bail out instead of dereferencing None below.
            return None

        # 1. Collect all descendants (children first, then the node itself)
        nodes_to_remove: List[TreeNode] = []

        def _collect_and_remove(_cl_node: TreeNode):
            for cld in _cl_node.children.values():
                _collect_and_remove(cld)
            nodes_to_remove.append(_cl_node)

        for child in target_node.children.values():
            _collect_and_remove(child)

        # 2. Perform the actual removal from the lookup maps
        for node in nodes_to_remove:
            self._uuid_map.pop(node.uuid, None)
            self._leaf_nodes.pop(node.uuid, None)

        # 3. Detach the children; the target node is a leaf again
        target_node.children.clear()
        self._leaf_nodes[target_node.uuid] = target_node

        return target_node

    def find_nodes_by_trace_route(self, trace_route: List[str]) -> TreeNode:
        """
        Follow ``trace_route`` (a list of node uuids starting at the root)
        down the tree and return the node it ends at.

        Raises:
            ValueError: if the tree is empty, the route is empty, the route
                does not start at the root, or a step is not a child of the
                node reached so far.
        """
        if self._root is None or not trace_route or trace_route[0] != self._root.uuid:
            raise ValueError("trace_route[0] is wrong!")
        cur_node = self._root
        for next_node_uuid in trace_route[1:]:
            if next_node_uuid not in cur_node.children:
                raise ValueError(f"Node with UUID '{next_node_uuid=}' not found in {cur_node.children=}.")
            cur_node = cur_node.children[next_node_uuid]

        return cur_node

    def in_leaf(self, node_uuid: str) -> bool:
        """Return whether ``node_uuid`` identifies a current leaf node."""
        return node_uuid in self._leaf_nodes

    # Helper to print the tree structure
    def print_tree(self, node: Optional[TreeNode] = None, level: int = 0):
        """Log the subtree below ``node`` (default: the root), one node per line."""
        if node is None:
            node = self._root
            if not node:
                logger.info("Tree is empty.")
                return

        logger.info('  ' * level + f"- {node.node_name} ({node.uuid[:4]}...)")
        for child in node.children.values():
            self.print_tree(child, level + 1)


class GraphEngineDebugger(GraphEngine):
    """Graph engine variant that executes step by step and records every
    runtime task in a ``RuntimeTaskTree``, enabling breakpoints, stepping,
    restoring earlier steps and editing pending inputs.
    """

    db_file_path: str = Field(default=".state_storage.db")

    _trace_tree: RuntimeTaskTree = PrivateAttr(default=None)
    _bp_set: set = PrivateAttr(default=None)
    _inverse_graph: Dict[str, Set[str]] = PrivateAttr(default=None)
    _waiting_inputs_backup: Dict[str, Any] = PrivateAttr(default=None)
    _true_waiting_inputs: Dict[str, Any] = PrivateAttr(default=None)
    _merged_mapping: Dict[str, Any] = PrivateAttr(default=None)

    def _build_inverse_graph(self):
        """Build ``_inverse_graph``: target node name -> set of source names."""
        self._inverse_graph = {}
        for source, edges in self.edges.items():
            for edge in edges:
                if isinstance(edge, ConditionEdgeSpec):
                    for pos_tar in edge.possible_targets:
                        self._inverse_graph.setdefault(pos_tar, set()).add(source)
                elif isinstance(edge, EdgeSpec):
                    self._inverse_graph.setdefault(edge.target, set()).add(source)

    def model_post_init(self, context: Any, /) -> None:
        """Create the debugger bookkeeping and bring the engine to a clean state."""
        # NOTE(review): ``db_file_path`` is not passed on here, so the trace
        # tree uses its own default ``db_path`` - confirm whether intended.
        self._trace_tree = RuntimeTaskTree()
        self._bp_set = set()
        self._waiting_inputs_backup = {}
        self._merged_mapping = {}
        self.reset()
        self._build_inverse_graph()

    def save_status_to_db(self):
        """Persist the trace tree."""
        self._trace_tree.save_tree()

    def load_status_from_db(self):
        """Load the persisted trace tree."""
        self._trace_tree.load_tree()

    def reset(self):
        """Reset all runtime state; breakpoints are kept.

        Raises:
            RuntimeError: if the graph is running or has no state schema.
        """
        self._check_can_running()
        self._state_root = StateCkpt(delta=None, parent=None, state_schema=self.state_schema)
        self._queue = asyncio.Queue()
        self._waiting_inputs = self._analyze_graph()
        self._true_waiting_inputs = self._analyze_graph_true_waiting()
        self._output_channels: asyncio.Queue[RunTimeTask] = asyncio.Queue()
        self._node_exec_cnt = 0
        self._trace_tree.clear_tree()
        self._waiting_inputs_backup.clear()
        self._merged_mapping.clear()

    def _check_can_running(self):
        """Raise ``RuntimeError`` if the graph is running or lacks a state schema."""
        if self._is_running:
            raise RuntimeError("This graph is still running, cannot reset.")
        if self.state_schema is None:
            raise RuntimeError("A schema of state must be assigned.")

    def _analyze_graph_true_waiting(self):
        """Return ``{target: {edge group: {source: None}}}`` for every possible edge."""
        config = dict()
        for name, edges in self.edges.items():
            for edge in edges:
                edge: EdgeSpecBase
                for target in edge.get_possible_targets():
                    config.setdefault(target, {}).setdefault(edge.group, {})[name] = None
        return config

    def set_breakpoint(self, /, node_name_bp=None, condition_bp=None, condition=None):
        """
        Set a breakpoint on the node named ``node_name_bp``.

        Conditional breakpoints (``condition_bp`` / ``condition``) are not
        supported yet; a ``RuntimeError`` is raised when no node name is given.
        """
        if node_name_bp is None:
            raise RuntimeError("condition_bp is not supported.")
        self._bp_set.add(node_name_bp)

    def clear_breakpoint(self, /, node_name_bp=None, condition_bp=None, condition=None):
        """
        Clear the breakpoint on the node named ``node_name_bp``.

        Raises:
            RuntimeError: if no node name is given.
            KeyError: if no breakpoint is set on that node.
        """
        if node_name_bp is None:
            raise RuntimeError("condition_bp is not supported.")
        self._bp_set.remove(node_name_bp)

    def clear_all_breakpoint(self):
        """Remove every breakpoint."""
        self._bp_set.clear()

    def _all_task_finished(self):
        """Return True when every leaf of the trace tree is a finished-branch marker."""
        return all(
            leaf.node_name == BranchFinished
            for leaf in self._trace_tree._leaf_nodes.values()
        )

    async def resume(self, running_queue=None):
        """
        Resume the program.

        Without ``running_queue`` the tasks of all current leaves that are not
        on a breakpoint are run.  Steps are executed until no runnable task
        remains; the result of the last step is returned.
        """
        if self._is_running:
            raise RuntimeError("This graph is still running, cannot run again.")
        if running_queue is None:
            # Finished-branch markers are leaves too but carry no task, so
            # they must not be queued for execution.
            running_queue = [
                node.task for node in self._trace_tree._leaf_nodes.values()
                if node.task is not None and node.node_name not in self._bp_set
            ]
        one_step_result = None
        while running_queue:
            one_step_result, running_queue = await self.run_one_step(running_queue)
        self._is_running = False
        return one_step_result

    def _executable_node(self, node_name_from: str, node_name_to: str):
        """Return whether ``node_name_to`` is reachable from ``node_name_from``
        along the graph edges (a node always reaches itself)."""
        if node_name_from == node_name_to:
            return True
        bfs_queue = deque([node_name_from])
        # Mark nodes when they are enqueued so each one is expanded only once.
        visited = {node_name_from}
        while bfs_queue:
            node_name = bfs_queue.popleft()
            # Names without outgoing edges (e.g. finished-branch markers) are
            # dead ends rather than errors.
            for next_step_edge in self.edges.get(node_name, ()):
                if isinstance(next_step_edge, EdgeSpec):
                    next_names = (next_step_edge.target,)
                elif isinstance(next_step_edge, ConditionEdgeSpec):
                    next_names = next_step_edge.possible_targets
                else:
                    continue
                for next_step_node_name in next_names:
                    if next_step_node_name == node_name_to:
                        return True
                    if next_step_node_name not in visited:
                        visited.add(next_step_node_name)
                        bfs_queue.append(next_step_node_name)
        return False

    def _get_run_nodes(self, node_name: str):
        """Return the tasks of all leaves (other than ``node_name``) from
        which ``node_name`` can be reached."""
        return [
            node.task for node in self._trace_tree._leaf_nodes.values()
            if node.node_name != node_name and self._executable_node(node.node_name, node_name)
        ]

    def _get_step_over_nodes(self, runtime_node: TreeNode, candidate_nodes: List[TreeNode]):
        """Return ``runtime_node`` followed by the candidates that can still
        reach it.  Only ``runtime_node`` is returned for ANY-mode nodes and
        for nodes with at most one predecessor."""
        node = self.nodes[runtime_node.node_name]
        if node.action_mode == NodeActionMode.ANY or len(self._inverse_graph.get(node.node_name, [])) <= 1:
            return [runtime_node]
        run_nodes = [runtime_node]
        for cand_node in candidate_nodes:
            if cand_node.node_name != node.node_name and self._executable_node(cand_node.node_name, node.node_name):
                run_nodes.append(cand_node)
        return run_nodes

    async def step_over(self, node_uuid=None):
        """
        Step over the current breakpoint.

        When ``node_uuid`` names a current leaf only that node is executed;
        otherwise every unfinished leaf advances by one step.  For ALL-mode
        nodes the leaves that can still reach them are first run up to that
        node (using temporary breakpoints) so that its inputs are complete.
        """
        if self._trace_tree.in_leaf(node_uuid):
            leaf_node = self._trace_tree._leaf_nodes[node_uuid]
        else:
            leaf_node = None
        pre_nodes = {}
        new_breakpoints = set()

        # set step over nodes
        if leaf_node is not None:
            run_nodes = [leaf_node]
        else:
            logger.warning(f"node_name not specified, so will step over all current nodes")
            run_nodes = [
                node for node in self._trace_tree._leaf_nodes.values()
                if node.node_name != BranchFinished
            ]

        # only ALL-mode nodes collect previous nodes
        for node in run_nodes:
            if self.nodes[node.node_name].action_mode == NodeActionMode.ALL:
                step_over_nodes = self._get_step_over_nodes(
                    node, candidate_nodes=list(self._trace_tree._leaf_nodes.values())
                )
                if len(step_over_nodes) > 1:
                    pre_nodes.setdefault(step_over_nodes[0], []).extend(step_over_nodes[1:])

        # merge all dependency
        head_nodes = list(pre_nodes.keys())
        for idx, node in enumerate(head_nodes):
            step_over_nodes = self._get_step_over_nodes(
                node, candidate_nodes=head_nodes[:idx] + head_nodes[idx + 1:]
            )
            if len(step_over_nodes) > 1:
                for need_merge_node in step_over_nodes[1:]:
                    pre_nodes[step_over_nodes[0]].extend(pre_nodes[need_merge_node])

        # pre_nodes remove duplicates (order preserving)
        for head_node, cands in pre_nodes.items():
            new_cands = []
            has_uuid = set()
            for cand in cands:
                if cand.uuid not in has_uuid:
                    new_cands.append(cand)
                    has_uuid.add(cand.uuid)
            pre_nodes[head_node] = new_cands

        # if previous nodes not empty, run them up to the (temporary) breakpoints
        if pre_nodes:
            resume_queue = []
            for bk_node, nodes in pre_nodes.items():
                if bk_node.node_name not in self._bp_set:
                    new_breakpoints.add(bk_node.node_name)
                    self.set_breakpoint(bk_node.node_name)
                resume_queue.extend(nodes)
            try:
                await self.resume([node.task for node in resume_queue])
            finally:
                # Temporary breakpoints must not outlive this call, even when
                # the pre-run fails.
                for bk in new_breakpoints:
                    self.clear_breakpoint(bk)

        # A node merged during the pre-run is replaced by the node it was
        # merged into.  Nodes the pre-run already advanced are no longer
        # leaves and are skipped, and each node is run only once.
        final_nodes = []
        seen_uuids = set()
        for node in run_nodes:
            node = self._merged_mapping.get(node.uuid, node)
            if node.uuid in seen_uuids:
                continue
            if pre_nodes and not self._trace_tree.in_leaf(node.uuid):
                continue
            seen_uuids.add(node.uuid)
            final_nodes.append(node)

        result, _ = await self.run_one_step([node.task for node in final_nodes])
        return result

    def _restore_waiting_input(self, end_nodes: List[TreeNode]):
        """
        Given the nodes to restore to, walk from every leaf below them back up
        and reset the waiting-input slots on the way to their backed-up values.
        """
        start_nodes: List[TreeNode] = []

        def _dfs(node: TreeNode):
            if len(node.children) == 0:
                start_nodes.append(node)
            else:
                for child in node.children.values():
                    _dfs(child)

        for node in end_nodes:
            _dfs(node)

        end_uuids = {node.uuid for node in end_nodes}
        visited = set()
        while start_nodes:
            node = start_nodes.pop()
            if node.uuid in end_uuids or node.uuid in visited:
                continue
            visited.add(node.uuid)
            # Nodes without a task (finished-branch markers) and nodes that
            # have no waiting-input table have nothing to restore, but the
            # walk towards the end nodes still continues through them.
            group_inputs = None
            if node.task is not None and node.node_name in self._waiting_inputs:
                group_inputs = self._waiting_inputs[node.node_name][node.task.edge_group]
            for parent in node.parent.values():
                backup_key = "__".join([parent.uuid, node.uuid])
                if group_inputs is not None and backup_key in self._waiting_inputs_backup:
                    group_inputs[parent.node_name] = self._waiting_inputs_backup[backup_key]
                start_nodes.append(parent)

    def restore_step(self, node_uuid: str = None):
        """
        Restore the last step when ``node_uuid`` is None, otherwise restore
        back to the node identified by ``node_uuid``.
        """
        restore_nodes = []
        restore_waiting_input = []
        if node_uuid is None:
            # A list snapshot is enough: the nodes are only read below and
            # pruning happens after this loop, so the tree is not deep-copied.
            for node in list(self._trace_tree._leaf_nodes.values()):
                # Finished-branch markers are not graph nodes and therefore
                # have no action mode to look up.
                is_all_mode = (
                    node.node_name != BranchFinished
                    and self.nodes[node.node_name].action_mode == NodeActionMode.ALL
                )
                for node_parent in node.parent.values():
                    restore_nodes.append(node_parent)
                    if is_all_mode:
                        # ANY Mode Node will not restore, because the node runtime must through
                        # parent node computation
                        restore_waiting_input.append(node_parent)
        else:
            target = self._trace_tree.get_node_by_uuid(node_uuid)
            if target is None:
                logger.warning(f"Node with UUID '{node_uuid}' not found, nothing restored.")
                return
            restore_nodes.append(target)

        self._restore_waiting_input(restore_waiting_input)
        for node_parent in restore_nodes:
            self._trace_tree.backtrack_and_prune(node_parent.uuid)

    async def _process_node(
            self,
            runtime_task: RunTimeTask
    ) -> List[RunTimeTask]:
        """Process current node and return successor nodes"""
        node_name = runtime_task.node_name
        if self.max_turn and self._node_exec_cnt >= self.max_turn:
            logger.info(f"Maximum invocation count exceeded: {self._node_exec_cnt}")
            return []

        self._node_exec_cnt += 1
        node = self.nodes[node_name]

        # get all runtime value from all predecessors and restore full state
        input_runtimes = self._get_node_inputs(runtime_task)
        if isinstance(input_runtimes, list):
            runtime = self._merge_state(input_runtimes)
        else:
            runtime = self._filter_state(input_runtimes)

        parent_node = self._trace_tree.find_nodes_by_trace_route(runtime_task.trace_route)

        if SpecialNode.is_end_node(node_name):
            await self._output_channels.put(runtime)
            # Mark this branch as finished with a task-less marker leaf
            self._trace_tree.add_node(node_name=BranchFinished, task=None, parent_uuid=parent_node.uuid)
            return []

        full_state = runtime.state_ckpt.materialize()
        state_delta = await node(full_state)
        new_state_ckpt = StateCkpt(delta=state_delta, parent=runtime.state_ckpt, state_schema=self.state_schema)

        next_tasks = []
        for edge in self.edges[node_name]:
            targets = edge.get_targets(new_state_ckpt.materialize())
            for target, state_filter in targets:
                next_task = RunTimeTask(
                    node_name=target,
                    state_ckpt=new_state_ckpt,
                    edge_group=edge.group,
                    state_filter=state_filter,
                    predecessor=node_name,
                    # Each successor gets its own list so that extending its
                    # route can never alter the current task's route or the
                    # routes of sibling tasks.
                    trace_route=list(runtime_task.trace_route),
                )
                new_node = self._trace_tree.add_node(
                    node_name=target, task=next_task, parent_uuid=parent_node.uuid
                )
                next_task.trace_route = next_task.trace_route + [new_node.uuid]
                next_tasks.append(next_task)

        return next_tasks

    def _backup_waiting_input(self, runtime_task: RunTimeTask, value):
        """Store ``value`` under ``"<parent uuid>__<node uuid>"``, taken from
        the last two entries of the task's trace route."""
        if len(runtime_task.trace_route) < 2:
            backup_key = "__" + runtime_task.trace_route[-1]
        else:
            backup_key = "__".join(runtime_task.trace_route[-2:])
        self._waiting_inputs_backup[backup_key] = value

    def _update_waiting_inputs(self, runtime_task: RunTimeTask):
        """Record ``runtime_task`` as the input delivered by its predecessor."""
        self._true_waiting_inputs[runtime_task.node_name][runtime_task.edge_group][
            runtime_task.predecessor] = runtime_task
        if runtime_task.node_name not in self._waiting_inputs:
            return
        self._waiting_inputs[runtime_task.node_name][runtime_task.edge_group][
            runtime_task.predecessor] = runtime_task

    def _is_predecessor_all_ready(self, runtime_task: RunTimeTask):
        """Record ``runtime_task`` and report whether its target may run.

        Targets without a waiting-input table are always ready.  Otherwise the
        target is ready once every slot of the task's edge group is filled; in
        that case the runtime nodes created by the individual predecessors are
        merged into the node of ``runtime_task``.
        """
        self._true_waiting_inputs[runtime_task.node_name][runtime_task.edge_group][
            runtime_task.predecessor] = runtime_task
        if runtime_task.node_name not in self._waiting_inputs:
            return True

        group_inputs = self._waiting_inputs[runtime_task.node_name][runtime_task.edge_group]
        group_inputs[runtime_task.predecessor] = runtime_task
        if not all(x is not None for x in group_inputs.values()):
            return False

        # add merge nodes process
        all_need_merge_nodes_uuid = [rt.trace_route[-1] for rt in group_inputs.values()]
        target_uuid = runtime_task.trace_route[-1]
        from_nodes = [self._trace_tree.get_node_by_uuid(uid) for uid in all_need_merge_nodes_uuid]
        to_node = self._trace_tree.get_node_by_uuid(target_uuid)
        self._trace_tree.merge_nodes(
            from_nodes=from_nodes,
            to_node=to_node,
        )

        # updating _waiting_inputs_backup: make each backed-up value reachable
        # under the key of the merged node as well
        to_node_uuid = to_node.uuid
        for from_node in from_nodes:
            if from_node.uuid == to_node_uuid:
                continue
            from_parent_node_uuid = from_node.task.trace_route[-2]
            self._merged_mapping[from_node.uuid] = to_node
            self._waiting_inputs_backup["__".join([from_parent_node_uuid, to_node_uuid])] = (
                self._waiting_inputs_backup["__".join(from_node.task.trace_route[-2:])])
        return True

    async def run_one_step(self, running_queue):
        """Run every task of ``running_queue`` concurrently.

        Returns:
            ``(output, next_batch_candidates)``: ``output`` is the graph output
            when every leaf is finished, else ``None``; the candidates are the
            successor tasks that are ready and not stopped by a breakpoint.
        """
        self._check_can_running()

        # This is where we will collect tasks for the parallel run
        results_from_gather = await asyncio.wait_for(
            asyncio.gather(
                *[self._process_node(runtime_task) for runtime_task in running_queue],
            ),
            timeout=self.timeout
        )

        next_batch_candidates = []
        for runtime_task_list in results_from_gather:
            for runtime_task in runtime_task_list:
                if not SpecialNode.is_end_node(runtime_task.node_name):
                    # Remember the previous slot value before it is overwritten
                    self._backup_waiting_input(
                        runtime_task,
                        self._true_waiting_inputs[runtime_task.node_name][runtime_task.edge_group][
                            runtime_task.predecessor]
                    )
                if self._get_node_action_mode(runtime_task.node_name) == NodeActionMode.ALL:
                    self._update_waiting_inputs(runtime_task)
                if self._is_predecessor_all_ready(runtime_task):
                    if runtime_task.node_name not in self._bp_set:
                        next_batch_candidates.append(runtime_task)

        self._is_running = False
        if self._all_task_finished():
            output = await self.get_output()
        else:
            output = None
        return output, next_batch_candidates

    async def debug(self, inputs: Dict):
        """Reset the engine and run the graph from the start node with
        ``inputs`` until no runnable task remains (finished or stopped at
        breakpoints).  Returns the result of the last step."""
        self._check_can_running()
        self.reset()

        start_task = RunTimeTask(
            node_name=SpecialNode.START_NODE.value,
            state_ckpt=self._init_state_ckpt(inputs),
            edge_group=DEFAULT_EDGE_GROUP,
            state_filter=None,
            trace_route=[SpecialNode.START_NODE.value],
        )
        running_queue: List[RunTimeTask] = [start_task]
        new_node = self._trace_tree.add_node(
            node_name=SpecialNode.START_NODE.value,
            task=start_task,
        )
        # Trace routes are made of tree-node uuids, starting at the root
        start_task.trace_route = [new_node.uuid]

        one_step_result = None
        while running_queue:
            one_step_result, running_queue = await self.run_one_step(running_queue)
        self._is_running = False
        return one_step_result

    def change_output(self, from_node_uuid: str, to_node_uuid: str, change_key: str, change_value: Any):
        """Overwrite attribute ``change_key`` of the last message that node
        ``from_node_uuid`` delivered to the pending leaf ``to_node_uuid``.

        Nothing is changed when the target is an ANY-mode node.

        Raises:
            RuntimeError: if ``to_node_uuid`` is not a current leaf or
                ``from_node_uuid`` is unknown.
        """
        if to_node_uuid not in self._trace_tree._leaf_nodes:
            raise RuntimeError(f"change output only support current node!")
        target_runtime_node = self._trace_tree.get_node_by_uuid(to_node_uuid)
        source_runtime_node = self._trace_tree.get_node_by_uuid(from_node_uuid)
        if source_runtime_node is None:
            raise RuntimeError(f"Node with UUID '{from_node_uuid}' not found.")
        target_node_name = target_runtime_node.node_name
        source_node_name = source_runtime_node.node_name

        if self.nodes[target_node_name].action_mode != NodeActionMode.ANY:
            # Waiting inputs are keyed by the group of the edge that leads
            # INTO the target, which is recorded on the target's task; the
            # source's task carries the group of the edge into the source.
            edge_group = target_runtime_node.task.edge_group
            state_ckpt = self._waiting_inputs[target_node_name][edge_group][source_node_name].state_ckpt
            setattr(state_ckpt.delta['messages'][-1], change_key, change_value)
            setattr(state_ckpt.materialized_state_cache.messages[-1], change_key, change_value)