evofabric.core.graph._state_update 源代码
# -*- coding: utf-8 -*-
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import uuid
from typing import Any, Callable, Dict, List
from ..factory import safe_get_attr, safe_set_attr
from ..typing import (
cast_state_message, MISSING, StateMessage
)
[文档]
class StateUpdater:
    """
    Registry of named state-update strategies.

    A strategy is a callable that receives ``(old, new)`` and returns the
    merged value. Strategies are stored in one class-level mapping, so a
    registration is visible to every user of this class.

    usage:
        @StateUpdater.register("my_strategy")
        def my_strategy(old: Any, new: Any) -> Any:
            return new

        strategy = StateUpdater.get("my_strategy")
        merged = strategy(old_state, new_state)
    """

    # name -> strategy callable; shared at class level.
    _strategies: Dict[str, Callable[[Any, Any], Any]] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Callable], Callable]:
        """
        Build a decorator that registers the decorated callable under *name*.

        The decorated callable is returned unchanged, so it stays usable
        under its own name as well.

        :raises KeyError: (when the decorator is applied) if *name* is
            already registered.
        """
        def decorator(func: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]:
            # Names are unique: refuse to silently replace an existing strategy.
            if name in cls._strategies:
                raise KeyError(f"Strategy '{name}' already registered")
            cls._strategies[name] = func
            return func

        return decorator

    @classmethod
    def get(cls, name: str) -> Callable[[Any, Any], Any]:
        """
        Return the strategy registered under *name*.

        :raises KeyError: if no such strategy exists; the message lists the
            names that are available.
        """
        try:
            return cls._strategies[name]
        except KeyError:
            # ``from None`` hides the bare dict KeyError behind the clearer message.
            raise KeyError(f"Unknown strategy '{name}'. "
                           f"Available: {list(cls._strategies)}") from None

    @classmethod
    def list_strategies(cls) -> List[str]:
        """Return the registered strategy names as a new list."""
        return list(cls._strategies)

    @classmethod
    def registered(cls, name: str) -> bool:
        """Return True if a strategy is registered under *name*."""
        return name in cls._strategies
@StateUpdater.register('overwrite')
def _overwrite_state_update_strategy(old: Any = MISSING, new: Any = MISSING) -> Any:
    """
    Overwrite strategy: the new value wins whenever one is supplied.

    Returns ``old`` when ``new`` is ``MISSING`` (which is ``MISSING`` itself
    if ``old`` was not supplied either); otherwise returns ``new``.
    """
    # A missing ``new`` means "no update": keep whatever was there before.
    return old if new is MISSING else new
def _ensure_message_ids(messages: List[Any]) -> List[StateMessage]:
    """
    Return the items of *messages*, cast where needed, each with a ``msg_id``.

    Items that are not ``StateMessage`` instances go through
    ``cast_state_message``; any item whose ``msg_id`` is falsy is given a
    fresh UUID4 string. The returned list holds the objects that actually
    carry the ids, so callers must use it instead of the input list.
    """
    prepared = []
    for msg in messages:
        if not isinstance(msg, StateMessage):
            msg = cast_state_message(msg)
        if not safe_get_attr(msg, "msg_id"):
            safe_set_attr(msg, "msg_id", str(uuid.uuid4()))
        prepared.append(msg)
    return prepared


@StateUpdater.register('append_messages')
def _append_messages(old: List[StateMessage] = MISSING, new: List[StateMessage] = MISSING) -> List[StateMessage]:
    """
    Append messages.

    Returns every message from ``old`` followed by the messages from ``new``
    whose ``msg_id`` does not already occur in ``old``. A ``MISSING``
    argument is treated as an empty list. Messages without a ``msg_id`` get
    a generated UUID4 string.

    The previous implementation set generated ids on a temporary produced by
    ``cast_state_message`` and then built the result from the untouched
    input items, so ids generated for non-``StateMessage`` inputs were
    dropped. Normalising first keeps them.
    NOTE(review): this assumes ``cast_state_message`` returns a new object
    for non-``StateMessage`` input - confirm against its implementation.
    """
    old_messages = _ensure_message_ids([] if old is MISSING else old)
    new_messages = _ensure_message_ids([] if new is MISSING else new)

    # Only ids already present in ``old`` suppress a new message; duplicates
    # inside ``new`` itself are kept, as before.
    known_ids = {safe_get_attr(msg, "msg_id") for msg in old_messages}
    merged = old_messages + [
        msg for msg in new_messages
        if safe_get_attr(msg, "msg_id") not in known_ids
    ]
    # The original applied a final cast to every returned item; kept so the
    # returned objects go through the same conversion as before.
    return [cast_state_message(msg) for msg in merged]