# -*- coding: utf-8 -*-
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import asyncio
import inspect
import traceback
import uuid
from typing import Any, Callable, cast, Dict, List, Optional, Type, Union
from pydantic import Field, field_serializer, field_validator, PrivateAttr
from ._streaming import stream_writer_env, StreamCtx, StreamWriter
from ._utils import _effective_func, _has_stream_writer, _make_class_name
from ..factory import BaseComponent, ComponentFactory, get_func_serializer, safe_get_attr, safe_set_attr
from ..typing import NodeActionMode, State, StateDelta
from ...logger import get_logger
logger = get_logger()
[文档]
class NodeBase(BaseComponent):
    """Common base class of the node types declared in this module.

    It adds no fields or behaviour of its own; it exists so that node
    instances can be recognised with ``isinstance(obj, NodeBase)``.
    """
[文档]
class SyncNode(NodeBase):
    """Node invoked synchronously with the state as its only argument."""

    def __call__(self, state: State) -> StateDelta:
        """Produce a state delta from ``state``.

        This base implementation is a placeholder that returns ``None``;
        subclasses are expected to override it.
        """
        return None
[文档]
class AsyncNode(NodeBase):
    """Node invoked as a coroutine with the state as its only argument."""

    async def __call__(self, state: State) -> StateDelta:
        """Produce a state delta from ``state``.

        This base implementation is a placeholder that resolves to ``None``;
        subclasses are expected to override it.
        """
        return None
[文档]
class SyncStreamNode(NodeBase):
    """Synchronous node that also receives a stream writer."""

    def __call__(self, state: State, stream_writer: StreamWriter) -> StateDelta:
        """Produce a state delta while emitting intermediate stream messages.

        Intermediate messages are emitted with ``stream_writer.put(data)``.
        This base implementation is a placeholder that returns ``None``;
        subclasses are expected to override it.

        Example:
        ```python
        class MyStreamNode(SyncStreamNode):
            def __call__(self, state: State, stream_writer: StreamWriter) -> StateDelta:
                chunks = []
                for chunk in llm.chat():
                    chunks.append(chunk)
                    stream_writer.put(chunk)
                return {"msg": "".join(chunks)}
        ```
        """
        return None
[文档]
class AsyncStreamNode(NodeBase):
    """Coroutine node that also receives a stream writer."""

    async def __call__(self, state: State, stream_writer: StreamWriter) -> StateDelta:
        """Produce a state delta while emitting intermediate stream messages.

        Intermediate messages are emitted with ``stream_writer.put(data)``.
        This base implementation is a placeholder that resolves to ``None``;
        subclasses are expected to override it.

        Example:
        ```python
        class MyAsyncStreamNode(AsyncStreamNode):
            async def __call__(self, state: State, stream_writer: StreamWriter) -> StateDelta:
                chunks = []
                for chunk in llm.chat():
                    chunks.append(chunk)
                    stream_writer.put(chunk)
                return {"msg": "".join(chunks)}
        ```
        """
        return None
class _SyncPlainFactory:
    """Builds a ``SyncNode`` subclass instance around a ``func(state)`` callable."""

    def __init__(self, func: Callable[..., StateDelta]) -> None:
        # The callable every generated node forwards to.
        self.func = func

    def _call(self, state: State) -> StateDelta:
        """Forward ``state`` to the wrapped callable and return its result."""
        result = self.func(state)
        return result

    def build(self) -> NodeBase:
        """Create a new ``SyncNode`` subclass and return one instance of it."""
        class_name = _make_class_name("_SyncNode")
        # ``self._call`` is already bound to this factory, so the generated
        # class's instance is not passed in again as ``self`` when called.
        namespace = {"__call__": self._call}
        node_cls = cast(Type[SyncNode], type(class_name, (SyncNode,), namespace))
        return node_cls()
class _SyncStreamFactory:
    """Builds a ``SyncStreamNode`` subclass instance around ``func(state, stream_writer)``."""

    def __init__(self, func: Callable[..., StateDelta]) -> None:
        # The callable every generated node forwards to.
        self.func = func

    def _call(self, state: State, stream_writer: StreamWriter) -> StateDelta:
        """Forward both arguments positionally to the wrapped callable."""
        result = self.func(state, stream_writer)
        return result

    def build(self) -> NodeBase:
        """Create a new ``SyncStreamNode`` subclass and return one instance of it."""
        class_name = _make_class_name("_SyncStreamNode")
        # ``self._call`` is already bound to this factory, so the generated
        # class's instance is not passed in again as ``self`` when called.
        namespace = {"__call__": self._call}
        node_cls = cast(Type[SyncStreamNode], type(class_name, (SyncStreamNode,), namespace))
        return node_cls()
class _AsyncPlainFactory:
    """Builds an ``AsyncNode`` subclass instance around an awaitable ``func(state)``."""

    def __init__(self, func: Callable[..., Any]) -> None:
        # The callable every generated node forwards to.
        self.func = func

    async def _call(self, state: State) -> StateDelta:
        """Await the wrapped callable with ``state`` and return its result."""
        result = await self.func(state)
        return result

    def build(self) -> NodeBase:
        """Create a new ``AsyncNode`` subclass and return one instance of it."""
        class_name = _make_class_name("_AsyncNode")
        # ``self._call`` is already bound to this factory, so the generated
        # class's instance is not passed in again as ``self`` when called.
        namespace = {"__call__": self._call}
        node_cls = cast(Type[AsyncNode], type(class_name, (AsyncNode,), namespace))
        return node_cls()
class _AsyncStreamFactory:
    """Builds an ``AsyncStreamNode`` subclass instance around an awaitable ``func(state, stream_writer)``."""

    def __init__(self, func: Callable[..., Any]) -> None:
        # The callable every generated node forwards to.
        self.func = func

    async def _call(self, state: State, stream_writer: StreamWriter) -> StateDelta:
        """Await the wrapped callable with both arguments passed positionally."""
        result = await self.func(state, stream_writer)
        return result

    def build(self) -> NodeBase:
        """Create a new ``AsyncStreamNode`` subclass and return one instance of it."""
        class_name = _make_class_name("_AsyncStreamNode")
        # ``self._call`` is already bound to this factory, so the generated
        # class's instance is not passed in again as ``self`` when called.
        namespace = {"__call__": self._call}
        node_cls = cast(Type[AsyncStreamNode], type(class_name, (AsyncStreamNode,), namespace))
        return node_cls()
[文档]
def callable_to_node(callable_obj: Callable[..., Any]) -> NodeBase:
    """Wrap a callable in one of the four ``NodeBase`` flavours.

    The flavour is chosen from two properties of the function returned by
    ``_effective_func(callable_obj)``: whether it is a coroutine function,
    and whether ``_has_stream_writer`` reports true for it.

    Args:
        callable_obj: The callable to convert.

    Returns:
        A node instance built by the matching private factory.
    """
    target = _effective_func(callable_obj)
    runs_async = inspect.iscoroutinefunction(target)
    takes_writer = _has_stream_writer(target)

    if runs_async:
        chosen_factory = _AsyncStreamFactory if takes_writer else _AsyncPlainFactory
    else:
        chosen_factory = _SyncStreamFactory if takes_writer else _SyncPlainFactory

    # The factory receives the effective function, not the original object.
    return chosen_factory(target).build()
class StartNode(SyncNode):
    """Node that logs the incoming state and returns an empty delta."""

    def __call__(self, state: State) -> StateDelta:
        """Log ``state`` at info level and return ``{}``."""
        logger.info(f"[Node start] Input: {state}")
        no_changes = {}
        return no_changes
class EndNode(SyncNode):
    """Node that ignores its input and returns an empty delta."""

    def __call__(self, state: State) -> StateDelta:
        """Return a fresh empty dict regardless of ``state``."""
        return dict()
[文档]
class GraphNodeSpec(BaseComponent):
    """Pairs a node (or plain callable) with its name and execution settings.

    A plain callable given as ``node`` is converted to a ``NodeBase`` instance
    in ``model_post_init``; the original callable is kept so that the spec can
    later be serialized through the function serializer.
    """

    node: Union[Callable, NodeBase]
    node_name: str
    action_mode: NodeActionMode = Field(default=NodeActionMode.ALL)
    stream_writer: Optional[StreamWriter] = Field(default=None)
    multi_input_merge_strategy: Optional[Dict[str, Callable[[List[State]], State]]] = Field(default=None)
    node_id: str = Field(default_factory=lambda: str(uuid.uuid4()))

    # Held for the whole duration of ``__call__``.
    _lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
    # True only while ``__call__`` is executing the node.
    _active: bool = PrivateAttr(default=False)
    # The user-supplied callable, when ``node`` was not already a NodeBase.
    _original_node_callable: Optional[Callable] = PrivateAttr(default=None)

    def model_post_init(self, context: Any, /) -> None:
        """Convert a plain callable ``node`` into a ``NodeBase`` instance."""
        if callable(self.node) and not isinstance(self.node, NodeBase):
            self._original_node_callable = self.node
            self.node = callable_to_node(self.node)

    @field_serializer('node')
    def serialize_node(self, _value: NodeBase, _info) -> Dict[str, Any]:
        """Serialize ``node`` as either the original callable or a node instance.

        Returns:
            ``{'type': 'callable', 'data': ...}`` when the spec was built from
            a plain callable, otherwise a ``'node_instance'`` record holding
            the node's ``model_dump()`` and its class name.
        """
        # Compare against None explicitly: a callable object may be falsy
        # (e.g. it defines ``__len__`` or ``__bool__``) and must still be
        # serialized as a callable.
        if self._original_node_callable is not None:
            serialized_func = get_func_serializer().serialize(self._original_node_callable)
            return {'type': 'callable', 'data': serialized_func}
        return {
            'type': 'node_instance', 'data': self.node.model_dump(), "node_class_name": self.node.__class__.__name__
        }

    @field_validator('node', mode='before')
    @classmethod
    def deserialize_node(cls, v: Any) -> Any:
        """Rebuild ``node`` from the record produced by ``serialize_node``.

        Any value that is not a dict carrying both ``'type'`` and ``'data'``
        (or whose ``'type'`` is not recognised) is returned unchanged.
        """
        if isinstance(v, dict) and 'type' in v and 'data' in v:
            if v['type'] == 'callable':
                return get_func_serializer().deserialize(v['data'])
            elif v['type'] == 'node_instance':
                return ComponentFactory.create(v["node_class_name"], **v['data'])
        return v

    @field_serializer('multi_input_merge_strategy')
    def serialize_merge_strategy(self, strategy: Optional[Dict[str, Callable]]) -> Optional[Dict[str, str]]:
        """Serialize every merge function in ``strategy``; ``None`` passes through."""
        if strategy is None:
            return strategy
        return {key: get_func_serializer().serialize(func) for key, func in strategy.items()}

    @field_validator('multi_input_merge_strategy', mode='before')
    @classmethod
    def deserialize_merge_strategy(cls, v: Any) -> Optional[Dict[str, Callable]]:
        """Deserialize merge functions, leaving already-callable values as they are."""
        if v is None:
            return None
        if isinstance(v, dict):
            # Use a distinct loop name so the parameter ``v`` is not shadowed.
            return {
                key: item if callable(item) else get_func_serializer().deserialize(item)
                for key, item in v.items()
            }
        return v

    @property
    def is_active(self) -> bool:
        """Whether the node is currently being executed by ``__call__``."""
        return self._active

    def _inject_node_name(self, delta):
        """Set ``node_name`` on every entry of ``delta``'s ``messages``, if any."""
        if _msg := safe_get_attr(delta, "messages", []):
            for msg in _msg:
                safe_set_attr(msg, "node_name", self.node_name)
        return delta

    async def __call__(self, state: State, **kwargs) -> StateDelta:
        """Run the node while holding the lock and flagging the spec as active.

        Args:
            state: State passed to the node.
            **kwargs: Accepted but not used by this method.

        Returns:
            The delta produced by the node, with ``node_name`` injected into
            its messages when that injection succeeds.
        """
        async with self._lock:
            self._active = True
            try:
                delta = await self._run(state)
                try:
                    delta = self._inject_node_name(delta)
                except Exception as e:
                    # Best effort: a failed injection is logged, and the
                    # delta is still returned.
                    logger.warning(
                        f"Injecting node name in `messages` failed, reason: {e}\ntraceback: {traceback.format_exc()}")
                logger.info(f"[Node {self.node_name}] Output: {delta}")
                return delta
            finally:
                self._active = False

    async def _run(self, state: State) -> StateDelta:
        """Invoke the node according to its type inside a stream-writer environment.

        Raises:
            TypeError: If ``node`` is none of the four supported node types.
        """
        with stream_writer_env(StreamCtx(call_id=str(uuid.uuid4()), node_name=self.node_name)):
            if isinstance(self.node, SyncNode):
                return self.node(state)
            elif isinstance(self.node, AsyncNode):
                return await self.node(state)
            elif isinstance(self.node, SyncStreamNode):
                return self.node(state, stream_writer=self.stream_writer)
            elif isinstance(self.node, AsyncStreamNode):
                return await self.node(state, stream_writer=self.stream_writer)
        raise TypeError("Node type not recognized")