# Source code for evofabric.core.graph._streaming
# -*- coding: utf-8 -*-
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import asyncio
import contextvars
import inspect
import threading
from contextlib import contextmanager
from typing import Any, Awaitable, Callable, Optional, Union

from pydantic import BaseModel

from ..factory import BaseComponent
from ...logger import get_logger
# Module-level logger obtained from the package's logging helper.
# NOTE(review): not referenced anywhere in this module as shown — confirm
# before removing.
logger = get_logger()
def default_handler(x):
    """Fallback stream handler: accept a message envelope and discard it.

    Args:
        x: The message envelope (ignored).

    Returns:
        None, always.
    """
    return None
# Serialises writes to _ON_MESSAGE performed by set_streaming_handler().
# Readers (StreamWriter.put) read the global without taking this lock.
_LOCK = threading.Lock()
# Handler that receives every streamed message envelope. It may return None
# or an awaitable, per its annotation; starts out as the no-op default.
_ON_MESSAGE: Callable[[dict], Union[None, Awaitable[None]]] = default_handler
def set_streaming_handler(
    callback: Callable[[dict], Union[None, Awaitable[None]]]
) -> None:
    """Install ``callback`` as the receiver of every streamed message.

    The callback is invoked with one envelope dict per ``StreamWriter.put``
    call and replaces whichever handler was registered before (initially
    ``default_handler``).

    Args:
        callback: A regular function, or an ``async def`` function, that
            accepts the envelope dict.
    """
    global _ON_MESSAGE
    # The lock only orders concurrent registrations against each other.
    with _LOCK:
        _ON_MESSAGE = callback
# Context attached to every streamed message: node-level identifiers for the
# graph node call that emitted it, tool-level identifiers for a tool
# invocation, and free-form ``meta`` for anything else.
class StreamCtx(BaseModel):
    # Node-level context.
    node_name: Optional[str] = None
    call_id: Optional[str] = None
    # Tool-level context.
    tool_name: Optional[str] = None
    tool_call_id: Optional[str] = None
    # Any other context, kept as metadata.
    meta: Optional[dict] = None

    def __bool__(self) -> bool:
        """Return True if any identifying field is truthy, else False.

        Only ``node_name``, ``call_id``, ``tool_name`` and ``tool_call_id`` are
        inspected; ``meta`` never affects truthiness.
        """
        identifiers = (
            self.node_name,
            self.call_id,
            self.tool_name,
            self.tool_call_id,
        )
        return any(identifiers)

    def __repr__(self) -> str:
        """Render the context as ``Node(...) -> Tool(...)`` segments.

        Missing values inside a present segment are shown as ``?``; a context
        with no identifiers at all renders as ``StreamCtx(empty)``.
        """
        has_node = bool(self.node_name or self.call_id)
        has_tool = bool(self.tool_name or self.tool_call_id)
        if not (has_node or has_tool):
            return "StreamCtx(empty)"
        segments = []
        if has_node:
            segments.append(
                f"Node(name={self.node_name or '?'}, id={self.call_id or '?'})"
            )
        if has_tool:
            segments.append(
                f"Tool(name={self.tool_name or '?'}, id={self.tool_call_id or '?'})"
            )
        return " -> ".join(segments)
# Context variable holding the active StreamCtx. stream_writer_env() sets and
# resets it; current_ctx() reads it. The default is a single shared, empty
# StreamCtx instance, so it should be treated as read-only.
_STREAM_CTX: contextvars.ContextVar[StreamCtx] = contextvars.ContextVar(
    "ctx",
    default=StreamCtx(),
)
def current_ctx() -> StreamCtx:
    """Return the StreamCtx active in the current execution context.

    Unless the context variable has been set (e.g. by ``stream_writer_env``),
    this is the shared default ``StreamCtx()``. The object is returned as-is,
    not copied, so callers should not mutate it.
    """
    active = _STREAM_CTX.get()
    return active
@contextmanager
def stream_writer_env(ctx_updates: StreamCtx):
    """Temporarily overlay ``ctx_updates`` onto the current stream context.

    Only fields explicitly set on ``ctx_updates`` override the parent context.
    A non-None ``meta`` is merged key-by-key over the parent's ``meta`` into a
    new dict (the parent's dict is left untouched). The previous context is
    restored on exit, even when the body raises.

    Example:
        with stream_writer_env(StreamCtx(node_name='NodeA')):
            # ...
    """
    outer = _STREAM_CTX.get()
    overrides = ctx_updates.model_dump(exclude_unset=True)
    extra_meta = ctx_updates.meta
    if extra_meta is not None:
        # Child meta keys win over parent keys; build a fresh dict.
        overrides["meta"] = {**(outer.meta or {}), **extra_meta}
    reset_token = _STREAM_CTX.set(outer.model_copy(update=overrides))
    try:
        yield
    finally:
        _STREAM_CTX.reset(reset_token)
# Strong references to handler tasks scheduled by StreamWriter.put(). The
# event loop keeps only weak references to tasks, so without this set a
# pending handler task could be garbage-collected before it completes.
_PENDING_HANDLER_TASKS = set()


def _schedule_handler_result(awaitable: Awaitable[Any]) -> None:
    """Schedule an awaitable returned by the stream handler as a task.

    The task goes on the loop running in the current thread. If none is
    running, this falls back to ``asyncio.get_event_loop()``; a task placed on
    a loop that is not running only executes once that loop is run.

    Raises:
        RuntimeError: if no event loop can be obtained for this thread.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # Nothing will ever await it; close the coroutine to avoid a
            # "coroutine was never awaited" warning, then propagate.
            if inspect.iscoroutine(awaitable):
                awaitable.close()
            raise
    task = asyncio.ensure_future(awaitable, loop=loop)
    _PENDING_HANDLER_TASKS.add(task)
    task.add_done_callback(_PENDING_HANDLER_TASKS.discard)


class StreamWriter(BaseComponent):
    """Publish stream messages, tagged with the current StreamCtx, to the handler."""

    @staticmethod
    def put(payload: Any) -> None:
        """Put streaming msg into stream writer and trigger the msg handler.

        The envelope passed to the registered handler contains the StreamCtx
        fields that were explicitly set, plus a ``"payload"`` key. If the
        handler returns an awaitable (an ``async def`` handler, or any callable
        returning an awaitable), it is scheduled as a background task and
        ``put`` returns without waiting for it. A synchronous handler runs
        inline.

        Args:
            payload: Arbitrary message content to stream.

        Raises:
            RuntimeError: if the handler returns an awaitable and no event
                loop is available in the current thread.
        """
        ctx = current_ctx()
        envelope = {
            **ctx.model_dump(exclude_unset=True),
            "payload": payload,
        }
        # Call first, then inspect the result: this covers async functions,
        # async callable objects and wrappers returning awaitables alike.
        result = _ON_MESSAGE(envelope)
        if inspect.isawaitable(result):
            _schedule_handler_result(result)
# Single module-level StreamWriter handed out by get_stream_writer().
_G_STREAM_WRITER = StreamWriter()


def get_stream_writer():
    """Return the module-level StreamWriter instance.

    Every call returns the same object.
    """
    writer = _G_STREAM_WRITER
    return writer