# Source code for evofabric.core.tool._tool_utils

# -*- coding: utf-8 -*-
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import copy
import fnmatch
import functools
import importlib.util
import inspect
import os
import types
import uuid
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import Annotated, Any, AsyncGenerator, Callable, get_args, get_origin, List, Optional

import mcp
from docstring_parser import Docstring, parse
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from pydantic import ConfigDict, create_model, Field

from ..typing import McpServerLink, SseLink, StdioLink, StreamableHttpLink


def load_function(
        file_path: str,
        include_patterns: List[str] = None,
        exclude_patterns: List[str] = None
):
    """
    Import all functions from a given Python file and return them as a dictionary.

    This function dynamically loads a Python file as a module, then extracts all module-level
    functions (excluding classes and methods) that pass the include/exclude pattern filters.

    Args:
        file_path (str):
            Absolute or relative path to the target Python file. Must be a valid .py file.
        include_patterns (List[str], optional):
            List of Unix shell-style wildcards (e.g., 'test_*') that function names must match ALL patterns.
            If None, no include filtering is applied. Defaults to None.
        exclude_patterns (List[str], optional):
            List of Unix shell-style wildcards (e.g., '_*') that function names must NOT match ANY pattern.
            If None, no exclude filtering is applied. Defaults to None.

    Returns:
        dict:
            Dictionary mapping function names (str) to their callable objects. Only includes:
            - Module-level functions (not class methods)
            - Functions passing the include_patterns (if specified)
            - Functions NOT excluded by exclude_patterns (if specified)

    Raises:
        ImportError:
            If the file cannot be loaded as a Python module (invalid syntax, missing path, etc.)
        FileNotFoundError:
            If file_path points to a non-existent file (raised by underlying filesystem calls)
    """
    module_name = os.path.splitext(os.path.basename(file_path))[0]

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"{file_path}")

    module = importlib.util.module_from_spec(spec)

    spec.loader.exec_module(module)

    functions = {}
    for name, obj in module.__dict__.items():
        if (include_patterns is not None) and (not all(fnmatch.fnmatch(name, pattern) for pattern in include_patterns)):
            continue

        if (exclude_patterns is not None) and (any(fnmatch.fnmatch(name, pattern) for pattern in exclude_patterns)):
            continue

        if callable(obj) and isinstance(obj, types.FunctionType):
            functions[name] = obj

    return functions


def _safe_peel_partial(function):
    """Return the callable underneath any ``functools.partial`` layers.

    Objects that are not ``functools.partial`` instances are returned as-is.
    """
    innermost = function
    while True:
        if not isinstance(innermost, functools.partial):
            return innermost
        innermost = innermost.func


def _parse_docstring(
        function: Callable,
        name: Optional[str] = None,
        description: Optional[str] = None,
        include_long_description: bool = True,
):
    """
    Extract the name, description and parsed docstring of a callable.

    ``functools.partial`` wrappers are peeled off first, so the docstring and
    the default name come from the underlying callable.

    Args:
        function: The callable whose docstring is parsed.
        name (Optional[str]): Name to report. When None, the ``__name__`` of the
                              (unwrapped) callable is used.
        description (Optional[str]): Description to report verbatim. When None,
                                     the description is assembled from the docstring.
        include_long_description (bool): Whether the docstring's long description
                                         is appended after the short one. Only
                                         relevant when ``description`` is None.
                                         Defaults to True.

    Returns:
        tuple: (parsed_docstring, function_name, formatted_description)
            - parsed_docstring: Result of `docstring_parser.parse()` on the docstring
            - function_name: The given or inferred name
            - formatted_description: The given description, or the short and long
              docstring descriptions joined by a blank line (empty string if the
              docstring provides neither)
    """
    target = _safe_peel_partial(function)
    docs = parse(target.__doc__)

    resolved_name = target.__name__ if name is None else name

    if description is not None:
        pieces = [description]
    else:
        pieces = [docs.short_description] if docs.short_description else []
        if include_long_description and docs.long_description:
            pieces.append(docs.long_description)

    return docs, resolved_name, "\n\n".join(pieces)


def _is_inspect_empty(sth):
    """Return True only if ``sth`` is the `inspect.Parameter.empty` sentinel.

    The check uses identity rather than ``==``: the sentinel is a unique
    object, and equality would delegate to ``sth.__eq__``. That can report a
    false match for objects with a permissive ``__eq__`` or return a
    non-boolean result (e.g. element-wise comparison on array-like defaults).
    """
    return sth is inspect.Parameter.empty


def _parse_func_params(
        function: Callable,
        docstring: Docstring = None,
        include_var_positional: bool = True,
        include_var_keyword: bool = True,
        exclude_params: Optional[List[str]] = None,
):
    """
    Parses function parameters using signature inspection and docstring metadata.

    For `functools.partial` objects, parameters that the partial already binds
    are excluded automatically. The binding is resolved against the signature
    of the wrapped callable: `inspect.signature` of a partial has already
    dropped positionally bound parameters, so binding the partial's arguments
    against that reduced signature would attribute them to the wrong (still
    unbound) parameters.

    Args:
        function: Target function for parameter inspection.
        docstring (Docstring): Parsed docstring object (from `docstring_parser.parse()`).
                              If None, will parse the function's docstring.
        include_var_positional (bool): Whether to include `*args` style parameters. Defaults to True.
        include_var_keyword (bool): Whether to include `**kwargs` style parameters. Defaults to True.
        exclude_params (Optional[List[str]]): Parameter names to exclude from processing.
                                              The passed list is not modified.

    Returns:
        tuple: (parameter_fields, excluded_parameters_list)
            - parameter_fields: Dict mapping parameter names to (type, Field) tuples
            - excluded_parameters_list: Names of parameters that appear in the
              inspected signature and were actually excluded
    """

    def get_annotation(param, default_type=Any):
        return default_type if _is_inspect_empty(param.annotation) else param.annotation

    def create_field(description, default_value):
        return Field(default=default_value, description=description)

    excluded_param_list = []
    # Work on a private copy: names bound by a partial are appended below.
    exclude_params = list(exclude_params or [])
    sig = inspect.signature(function)

    # for partial function, will exclude bound params
    while isinstance(function, functools.partial):
        # Bind against the wrapped callable's signature so that positional
        # arguments map to the parameters they really occupy.
        inner_sig = inspect.signature(function.func)
        arguments: inspect.BoundArguments = inner_sig.bind_partial(*function.args, **function.keywords)
        exclude_params.extend(arguments.arguments.keys())
        function = function.func

    # A dotted qualified name means the function is defined inside another
    # scope (e.g. a class), so a leading `self`/`cls` may need to be skipped.
    is_class_method = function.__qualname__ and "." in function.__qualname__
    docstring = docstring or parse(function.__doc__)
    param_desp_map: dict[str, str] = {x.arg_name: x.description for x in docstring.params}

    fields = {}
    # iteratively define param field
    for pidx, (name, param) in enumerate(sig.parameters.items()):
        description = param_desp_map.get(name, None)
        # `...` marks the field as required when the parameter has no default.
        default = param.default if not _is_inspect_empty(param.default) else ...
        # Never the empty sentinel: un-annotated parameters fall back to `Any`.
        annotation = get_annotation(param)

        if description is None and get_origin(annotation) is Annotated:
            # update description if argument define `Annotated`
            args = get_args(annotation)
            for arg in args[1:]:
                if isinstance(arg, str):
                    description = arg
                    break

        if is_class_method and pidx == 0 and name in {"cls", "self"}:
            # handle self and cls
            continue

        if name in exclude_params:
            # handle exclude params
            excluded_param_list.append(name)
            continue

        if param.kind == param.VAR_POSITIONAL:
            if include_var_positional:
                # handle function(*args)
                fields[name] = (list[annotation], create_field(description, None))

        elif param.kind == param.VAR_KEYWORD:
            if include_var_keyword:
                # handle function(**kwargs)
                fields[name] = (dict[str, annotation], create_field(description, None))

        else:
            # handle other params
            fields[name] = (annotation, create_field(description, default))

    return fields, excluded_param_list


def parse_callable_schema(
        function: Callable,
        name: Optional[str] = None,
        description: Optional[str] = None,
        include_long_description: bool = True,
        include_var_positional: bool = True,
        include_var_keyword: bool = True,
        exclude_params: Optional[List[str]] = None,
):
    """
    Build a function-calling JSON schema for a Python callable.

    The callable's signature, docstring and parameter defaults are turned
    into a temporary pydantic model, whose JSON schema becomes the
    ``parameters`` section of the returned tool description.

    Support function type:
        - python function
        - class method
        - @classmethod
        - @staticmethod
        - partial(function / class method)
        - lambda

    Args:
        function: Callable function to convert into a schema.
        name (Optional[str]): Custom function name. Defaults to function's __name__.
        description (Optional[str]): Custom description. Defaults to the description
                                     built from the docstring.
        include_long_description (bool): Whether to include long description from docstring.
        include_var_positional (bool): Whether to include *args parameters.
        include_var_keyword (bool): Whether to include **kwargs parameters.
        exclude_params (Optional[List[str]]): Parameter names to exclude.

    Returns:
        tuple: (json_schema, excluded_parameters_list)
            - json_schema: Dictionary of the form
              ``{"type": "function", "function": {"name", "description", "parameters"}}``
            - excluded_parameters_list: Names of parameters actually excluded
    """
    parsed_doc, name, description = _parse_docstring(
        function, name, description, include_long_description
    )
    field_definitions, skipped_params = _parse_func_params(
        function,
        parsed_doc,
        include_var_positional,
        include_var_keyword,
        exclude_params,
    )

    # A short random suffix is appended to the name of each throwaway model.
    throwaway_model = create_model(
        f"pyfunc_{name}_{uuid.uuid4().hex[:4]}",
        __config__=ConfigDict(arbitrary_types_allowed=True),
        **field_definitions,
    )

    parameters = throwaway_model.model_json_schema()
    # Strip the "title" entries from the model itself and from each property.
    parameters.pop("title", None)
    for property_schema in parameters["properties"].values():
        property_schema.pop("title", None)

    tool_schema: dict = {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }
    return tool_schema, skipped_params
def parse_mcp_tool_function(
        mcp_tool: mcp.types.Tool,
        server_name: str = None,
        include_long_description: bool = True,
) -> dict:
    """Converts MCP tool schema to OpenAI-compatible function call format.

    This function:
    1. Parses the tool's description as a docstring to extract structured metadata
    2. Joins the short (and optionally long) description into the function description
    3. Copies the tool's input properties and fills in parameter descriptions
       found in the docstring
    4. Builds the function-call dictionary, optionally prefixing the name

    Args:
        mcp_tool: Input tool schema; its ``description`` is parsed as a docstring
            and its ``inputSchema`` supplies the parameter properties.
        server_name: Optional server identifier to prefix function names
            (e.g., "myserver_calculate" when server_name="myserver").
        include_long_description: Controls whether the long description
            should be included in the function's main description (default: True).

    Returns:
        A dictionary with:
        - "type": "function"
        - "function":
            * "name": tool name, prefixed with ``server_name + "_"`` if specified
            * "parameters": deep copy of ``inputSchema["properties"]`` (empty dict
              if the tool declares none). Every property carries a "description"
              key: the docstring description when one exists, otherwise the
              description already present in the input schema, otherwise None.
            * "description": short/long description joined by a blank line

    Note:
        The input schema of ``mcp_tool`` is never modified.
    """
    docstring = parse(mcp_tool.description)
    params_docstring = {
        param.arg_name: param.description for param in docstring.params
    }

    # Function description
    description_parts = []
    if docstring.short_description is not None:
        description_parts.append(docstring.short_description)
    if include_long_description and docstring.long_description is not None:
        description_parts.append(docstring.long_description)
    description = "\n\n".join(description_parts)

    # Tools without arguments may omit "properties" altogether.
    params_json_schema = copy.deepcopy(mcp_tool.inputSchema.get('properties') or {})
    for name, info in params_json_schema.items():
        doc_description = params_docstring.get(name, None)
        if doc_description is not None:
            info['description'] = doc_description
        else:
            # Keep a description that the input schema already provides instead
            # of overwriting it with None; the key is still always present.
            info.setdefault('description', None)

    func_json_schema: dict = {
        "type": "function",
        "function": {
            "name": mcp_tool.name if server_name is None else server_name + "_" + mcp_tool.name,
            "parameters": params_json_schema,
            "description": description
        },
    }
    return func_json_schema


@asynccontextmanager
async def create_mcp_session(
        server_link: McpServerLink,
        sampling_callback=None,
) -> AsyncGenerator[ClientSession, None]:
    """
    Create an MCP client session for the given server parameters.

    Args:
        server_link: Connection description; a StdioLink, SseLink or
            StreamableHttpLink selects the transport.
        sampling_callback: Passed through to ``ClientSession`` unchanged.

    Yields:
        The ``ClientSession`` opened on the transport's streams. Both the
        session and the transport contexts are exited when the caller's
        ``async with`` block ends.

    Raises:
        NotImplementedError: If ``server_link`` is none of the supported link types.
    """
    # Shared setup for all transports
    read_timeout = timedelta(seconds=_get_read_timeout_seconds(server_link))

    # Select and invoke the appropriate transport client
    if isinstance(server_link, StdioLink):
        client_ctx = stdio_client(server_link)
    elif isinstance(server_link, SseLink):
        client_ctx = sse_client(**server_link.model_dump(exclude={"type"}))
    elif isinstance(server_link, StreamableHttpLink):
        params = server_link.model_dump(exclude={"type"})
        # The link stores these timeouts as numbers of seconds; they are
        # handed to the client as timedelta objects.
        params["timeout"] = timedelta(seconds=params["timeout"])
        params["sse_read_timeout"] = timedelta(seconds=params["sse_read_timeout"])
        client_ctx = streamablehttp_client(**params)
    else:
        raise NotImplementedError(
            f"Unsupported server params type: {type(server_link)}")

    # Enter transport context and extract streams
    async with client_ctx as streams:
        # Only the first two items are used, whatever else the transport yields.
        read, write = streams[0], streams[1]

        # Create and yield the ClientSession
        async with ClientSession(
                read_stream=read,
                write_stream=write,
                read_timeout_seconds=read_timeout,
                sampling_callback=sampling_callback,
        ) as session:
            yield session


def _get_read_timeout_seconds(server_link: McpServerLink, default=30.) -> float:
    """Extract the appropriate read timeout in seconds based on server params type.

    Returns ``read_time_out`` for a StdioLink, ``sse_read_timeout`` (300.0 if
    the attribute is missing) for an SseLink or StreamableHttpLink, and
    ``default`` for any other type.
    """
    if isinstance(server_link, StdioLink):
        return server_link.read_time_out
    elif isinstance(server_link, (SseLink, StreamableHttpLink)):
        return getattr(server_link, "sse_read_timeout", 300.0)
    else:
        return default