# Source code for evofabric.core.factory._factory
# -*- coding: utf-8 -*-
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
from __future__ import annotations
from typing import (
Any, ClassVar, Dict, Type, TypeVar
)
from pydantic import BaseModel, ConfigDict
from pydantic_core import core_schema
from typing_extensions import Unpack
from ...logger import get_logger
logger = get_logger()
T = TypeVar("T")
[文档]
class ComponentFactory:
    """Name-keyed registry of component classes with a construction helper.

    The registry is a single class-level dict, so every caller of the
    classmethods below reads and writes the same mapping.
    """

    # Maps a registered name to the class that `create` instantiates for it.
    _registry: ClassVar[Dict[str, Type[BaseComponent]]] = {}

    @classmethod
    def create(cls, name: str, /, **kwargs) -> BaseComponent:
        """Instantiate the class registered under ``name`` from ``kwargs``.

        Args:
            name: Registered component name (positional-only).
            **kwargs: Field values / constructor arguments for the component.

        Returns:
            The newly built instance.

        Raises:
            ValueError: If no class is registered under ``name``; the
                originating ``KeyError`` is attached as the cause.
        """
        try:
            target_cls = cls._registry[name]
        except KeyError as lookup_error:
            raise ValueError(f"Unknown component '{name}'") from lookup_error

        # Anything that is not a pydantic model is built with a plain call.
        if not issubclass(target_cls, BaseModel):
            return target_cls(**kwargs)
        # Pydantic models are built through model_validate on the kwargs dict.
        return target_cls.model_validate(kwargs)

    @classmethod
    def register(cls, name: str, component_cls: Type[BaseComponent]) -> None:
        """Record ``component_cls`` under ``name``.

        Raises:
            ValueError: If ``name`` is already present in the registry.
        """
        known = cls._registry
        if name in known:
            raise ValueError(f"Component name '{name}' already registered")
        known[name] = component_cls

    @classmethod
    def is_registered(cls, name: str) -> bool:
        """Return ``True`` when ``name`` has a class in the registry."""
        return name in cls._registry
[文档]
# Base pydantic model for components. Every subclass is registered with
# ComponentFactory at class-creation time (see __init_subclass__ below).
# NOTE: documented with comments rather than a class docstring because
# pydantic copies a model's docstring into its JSON schema description.
class BaseComponent(BaseModel):
    # Re-run field validation whenever an attribute is assigned.
    model_config = ConfigDict(validate_assignment=True)

    def __init_subclass__(cls, **kwargs: Unpack[ConfigDict]):
        """Register each new subclass with ``ComponentFactory``.

        The registry key is the class's ``__component_name__`` attribute when
        that attribute is present and truthy; otherwise the class's own
        ``__name__`` is used. ``ComponentFactory.register`` raises
        ``ValueError`` if the key is already taken.
        """
        super().__init_subclass__(**kwargs)
        registry_key = getattr(cls, "__component_name__", None)
        if not registry_key:
            registry_key = cls.__name__
        ComponentFactory.register(registry_key, cls)
[文档]
class FactoryTypeAdapter:
    """Pydantic core-schema hook that (de)serializes components by name.

    Validation accepts a dict carrying a ``"__class_name__"`` key and builds
    the matching component through ``ComponentFactory.create``; in Python mode
    an existing ``BaseComponent`` instance is accepted as-is. Serialization
    emits the component's dumped fields plus the ``"__class_name__"`` key.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        source_type,
        handler
    ) -> core_schema.CoreSchema:
        """Return the core schema implementing the behaviour described above.

        Args:
            source_type: The annotated source type (unused here).
            handler: Pydantic's schema handler (unused here).

        Returns:
            A JSON-or-Python schema with a custom plain serializer.
        """

        def validate_from_dict(data: Dict[str, Any]) -> BaseComponent:
            """Build a component from a dict containing ``"__class_name__"``.

            Raises:
                TypeError: If ``data`` is not a dict.
                ValueError: If the ``"__class_name__"`` key is missing or
                    falsy, or (from the factory) the name is not registered.
            """
            if not isinstance(data, dict):
                raise TypeError("Input must be a dict")
            # Work on a shallow copy so that stripping the discriminator key
            # never mutates the dict object handed to this validator.
            payload = dict(data)
            class_name = payload.pop("__class_name__", None)
            if not class_name:
                raise ValueError(f"Input must contain key '__class_name__', got {payload.keys()}")
            return ComponentFactory.create(class_name, **payload)  # type: ignore

        def serialize_to_dict(instance: BaseComponent) -> Dict[str, Any]:
            """Dump ``instance`` and tag it with its registry name."""
            component_cls = type(instance)
            # Mirror the rule BaseComponent.__init_subclass__ uses to pick the
            # registry key, so the emitted name can be resolved again by
            # ComponentFactory.create when the dict is validated back.
            class_name = (
                getattr(component_cls, "__component_name__", None)
                or component_cls.__name__
            )
            instance_data = instance.model_dump()
            return {"__class_name__": class_name, **instance_data}

        # First require a dict, then turn it into a component instance.
        from_dict_schema = core_schema.chain_schema([
            core_schema.dict_schema(),
            core_schema.no_info_plain_validator_function(validate_from_dict),
        ])

        return core_schema.json_or_python_schema(
            # JSON input only supports the tagged-dict form.
            json_schema=from_dict_schema,
            # Python input may also be an already-built component instance.
            python_schema=core_schema.union_schema([
                core_schema.is_instance_schema(BaseComponent),
                from_dict_schema,
            ]),
            serialization=core_schema.plain_serializer_function_ser_schema(
                serialize_to_dict,
                when_used='unless-none'
            ),
        )