nonebot-plugin-llmchat/nonebot_plugin_llmchat/mcpclient.py

228 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from nonebot import logger
from .config import MCPServerConfig
from .onebottools import OneBotTools
class MCPClient:
_instance = None
_initialized = False
def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, server_config: dict[str, MCPServerConfig] | None = None):
if self._initialized:
return
if server_config is None:
raise ValueError("server_config must be provided for first initialization")
logger.info(f"正在初始化MCPClient单例共有{len(server_config)}个服务器配置")
self.server_config = server_config
self.sessions = {}
self.exit_stack = AsyncExitStack()
# 添加工具列表缓存
self._tools_cache: list | None = None
self._cache_initialized = False
# 初始化OneBot工具
self.onebot_tools = OneBotTools()
self._initialized = True
logger.debug("MCPClient单例初始化成功")
@classmethod
def get_instance(cls, server_config: dict[str, MCPServerConfig] | None = None):
"""获取MCPClient实例"""
if cls._instance is None:
if server_config is None:
raise ValueError("server_config must be provided for first initialization")
cls._instance = cls(server_config)
return cls._instance
@classmethod
def instance(cls):
"""快速获取已初始化的MCPClient实例如果未初始化则抛出异常"""
if cls._instance is None:
raise RuntimeError("MCPClient has not been initialized. Call get_instance() first.")
return cls._instance
async def connect_to_servers(self):
logger.info(f"开始连接{len(self.server_config)}个MCP服务器")
for server_name, config in self.server_config.items():
logger.debug(f"正在连接服务器[{server_name}]")
if config.url:
sse_transport = await self.exit_stack.enter_async_context(sse_client(url=config.url, headers=config.headers))
read, write = sse_transport
self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
await self.sessions[server_name].initialize()
elif config.command:
stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(StdioServerParameters(**config.model_dump()))
)
read, write = stdio_transport
self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
await self.sessions[server_name].initialize()
else:
raise ValueError("Server config must have either url or command")
logger.info(f"已成功连接到MCP服务器[{server_name}]")
def _create_session_context(self, server_name: str):
"""创建临时会话的异步上下文管理器"""
config = self.server_config[server_name]
class SessionContext:
def __init__(self):
self.session = None
self.exit_stack = AsyncExitStack()
async def __aenter__(self):
if config.url:
transport = await self.exit_stack.enter_async_context(
sse_client(url=config.url, headers=config.headers)
)
elif config.command:
transport = await self.exit_stack.enter_async_context(
stdio_client(StdioServerParameters(**config.model_dump()))
)
else:
raise ValueError("Server config must have either url or command")
read, write = transport
self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
await self.session.initialize()
return self.session
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.exit_stack.aclose()
return SessionContext()
async def init_tools_cache(self):
"""初始化工具列表缓存"""
if not self._cache_initialized:
available_tools = []
logger.info(f"初始化工具列表缓存,需要连接{len(self.server_config)}个服务器")
for server_name in self.server_config.keys():
logger.debug(f"正在从服务器[{server_name}]获取工具列表")
async with self._create_session_context(server_name) as session:
response = await session.list_tools()
tools = response.tools
logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具")
available_tools.extend(
{
"type": "function",
"function": {
"name": f"mcp__{server_name}__{tool.name}",
"description": tool.description,
"parameters": tool.inputSchema,
},
}
for tool in tools
)
# 缓存工具列表
self._tools_cache = available_tools
self._cache_initialized = True
logger.info(f"工具列表缓存完成,共缓存{len(available_tools)}个工具")
async def get_available_tools(self, is_group: bool):
"""获取可用工具列表,使用缓存机制"""
await self.init_tools_cache()
available_tools = self._tools_cache.copy() if self._tools_cache else []
if is_group:
# 群聊场景包含OneBot工具和MCP工具
available_tools.extend(self.onebot_tools.get_available_tools())
logger.debug(f"获取可用工具列表,共{len(available_tools)}个工具")
return available_tools
async def call_tool(self, tool_name: str, tool_args: dict, group_id: int | None = None, bot_id: str | None = None):
"""按需连接调用工具,调用后立即断开"""
# 检查是否是OneBot内置工具
if tool_name.startswith("ob__"):
if group_id is None or bot_id is None:
return "QQ工具需要提供group_id和bot_id参数"
logger.info(f"调用OneBot工具[{tool_name}]")
return await self.onebot_tools.call_tool(tool_name, tool_args, group_id, bot_id)
# 检查是否是MCP工具
if tool_name.startswith("mcp__"):
# MCP工具处理mcp__server_name__tool_name
parts = tool_name.split("__")
if len(parts) != 3 or parts[0] != "mcp":
return f"MCP工具名称格式错误: {tool_name}"
server_name = parts[1]
real_tool_name = parts[2]
logger.info(f"按需连接到服务器[{server_name}]调用工具[{real_tool_name}]")
async with self._create_session_context(server_name) as session:
try:
response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30)
logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}")
return response.content
except asyncio.TimeoutError:
logger.error(f"调用工具[{real_tool_name}]超时")
return f"调用工具[{real_tool_name}]超时"
# 未知工具类型
return f"未知的工具类型: {tool_name}"
def get_friendly_name(self, tool_name: str):
logger.debug(tool_name)
# 检查是否是OneBot内置工具
if tool_name.startswith("ob__"):
return self.onebot_tools.get_friendly_name(tool_name)
# 检查是否是MCP工具
if tool_name.startswith("mcp__"):
# MCP工具处理mcp__server_name__tool_name
parts = tool_name.split("__")
if len(parts) != 3 or parts[0] != "mcp":
return tool_name # 格式错误时返回原名称
server_name = parts[1]
real_tool_name = parts[2]
return (self.server_config[server_name].friendly_name or server_name) + " - " + real_tool_name
# 未知工具类型,返回原名称
return tool_name
def clear_tools_cache(self):
"""清除工具列表缓存"""
logger.info("清除工具列表缓存")
self._tools_cache = None
self._cache_initialized = False
async def cleanup(self):
"""清理资源(不销毁单例)"""
logger.debug("正在清理MCPClient资源")
# 只清除缓存,不销毁单例
# self.clear_tools_cache() # 保留缓存,避免重复获取工具列表
await self.exit_stack.aclose()
# 重新初始化exit_stack以便后续使用
self.exit_stack = AsyncExitStack()
logger.debug("MCPClient资源清理完成")
@classmethod
async def destroy_instance(cls):
"""完全销毁单例实例(仅在应用关闭时使用)"""
if cls._instance is not None:
logger.info("销毁MCPClient单例")
await cls._instance.cleanup()
cls._instance.clear_tools_cache()
cls._instance = None
cls._initialized = False
logger.debug("MCPClient单例已销毁")