diff --git a/nonebot_plugin_llmchat/__init__.py b/nonebot_plugin_llmchat/__init__.py index 79290a1..ab07224 100755 --- a/nonebot_plugin_llmchat/__init__.py +++ b/nonebot_plugin_llmchat/__init__.py @@ -279,7 +279,7 @@ async def process_messages(group_id: int): event = await state.queue.get() logger.debug(f"从队列获取消息 群号:{group_id} 消息ID:{event.message_id}") past_events_snapshot = [] - mcp_client = MCPClient(plugin_config.mcp_servers) + mcp_client = MCPClient.get_instance(plugin_config.mcp_servers) try: systemPrompt = f""" 我想要你帮我在群聊中闲聊,大家一般叫你{"、".join(list(driver.config.nickname))},我将会在后面的信息中告诉你每条群聊信息的发送者和发送时间,你可以直接称呼发送者为他对应的昵称。 @@ -349,7 +349,6 @@ async def process_messages(group_id: int): } if preset.support_mcp: - await mcp_client.connect_to_servers() available_tools = await mcp_client.get_available_tools() client_config["tools"] = available_tools @@ -455,7 +454,8 @@ async def process_messages(group_id: int): finally: state.processing = False state.queue.task_done() - await mcp_client.cleanup() + # 不再需要每次都清理MCPClient,因为它现在是单例 + # await mcp_client.cleanup() # 预设切换命令 @@ -621,3 +621,5 @@ async def init_plugin(): async def cleanup_plugin(): logger.info("插件关闭清理") await save_state() + # 销毁MCPClient单例 + await MCPClient.destroy_instance() diff --git a/nonebot_plugin_llmchat/mcpclient.py b/nonebot_plugin_llmchat/mcpclient.py index c3f3224..d0bc80e 100644 --- a/nonebot_plugin_llmchat/mcpclient.py +++ b/nonebot_plugin_llmchat/mcpclient.py @@ -10,12 +10,46 @@ from .config import MCPServerConfig class MCPClient: - def __init__(self, server_config: dict[str, MCPServerConfig]): - logger.info(f"正在初始化MCPClient,共有{len(server_config)}个服务器配置") + _instance = None + _initialized = False + + def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self, server_config: dict[str, MCPServerConfig] | None = None): + if self._initialized: + return + + if server_config is None: + raise ValueError("server_config must be provided for first initialization") + + logger.info(f"正在初始化MCPClient单例,共有{len(server_config)}个服务器配置") self.server_config = server_config self.sessions = {} self.exit_stack = AsyncExitStack() - logger.debug("MCPClient初始化成功") + # 添加工具列表缓存 + self._tools_cache: list | None = None + self._cache_initialized = False + self._initialized = True + logger.debug("MCPClient单例初始化成功") + + @classmethod + def get_instance(cls, server_config: dict[str, MCPServerConfig] | None = None): + """获取MCPClient实例""" + if cls._instance is None: + if server_config is None: + raise ValueError("server_config must be provided for first initialization") + cls._instance = cls(server_config) + return cls._instance + + @classmethod + def instance(cls): + """快速获取已初始化的MCPClient实例,如果未初始化则抛出异常""" + if cls._instance is None: + raise RuntimeError("MCPClient has not been initialized. Call get_instance() first.") + return cls._instance async def connect_to_servers(self): logger.info(f"开始连接{len(self.server_config)}个MCP服务器") @@ -38,47 +72,113 @@ class MCPClient: logger.info(f"已成功连接到MCP服务器[{server_name}]") + def _create_session_context(self, server_name: str): + """创建临时会话的异步上下文管理器""" + config = self.server_config[server_name] + + class SessionContext: + def __init__(self): + self.session = None + self.exit_stack = AsyncExitStack() + + async def __aenter__(self): + if config.url: + transport = await self.exit_stack.enter_async_context( + sse_client(url=config.url, headers=config.headers) + ) + elif config.command: + transport = await self.exit_stack.enter_async_context( + stdio_client(StdioServerParameters(**config.model_dump())) + ) + else: + raise ValueError("Server config must have either url or command") + + read, write = transport + self.session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + await self.session.initialize() + return self.session + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.exit_stack.aclose() + + return SessionContext() + async def get_available_tools(self): - logger.info(f"正在从{len(self.sessions)}个已连接的服务器获取可用工具") + """获取可用工具列表,使用缓存机制""" + if self._tools_cache is not None: + logger.debug("返回缓存的工具列表") + return self._tools_cache + + logger.info(f"初始化工具列表缓存,需要连接{len(self.server_config)}个服务器") available_tools = [] - for server_name, session in self.sessions.items(): - logger.debug(f"正在列出服务器[{server_name}]中的工具") - response = await session.list_tools() - tools = response.tools - logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具") + for server_name in self.server_config.keys(): + logger.debug(f"正在从服务器[{server_name}]获取工具列表") + async with self._create_session_context(server_name) as session: + response = await session.list_tools() + tools = response.tools + logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具") - available_tools.extend( - { - "type": "function", - "function": { - "name": f"{server_name}___{tool.name}", - "description": tool.description, - "parameters": tool.inputSchema, - }, - } - for tool in tools - ) + available_tools.extend( + { + "type": "function", + "function": { + "name": f"{server_name}___{tool.name}", + "description": tool.description, + "parameters": tool.inputSchema, + }, + } + for tool in tools + ) + + # 缓存工具列表 + self._tools_cache = available_tools + self._cache_initialized = True + logger.info(f"工具列表缓存完成,共缓存{len(available_tools)}个工具") return available_tools async def call_tool(self, tool_name: str, tool_args: dict): + """按需连接调用工具,调用后立即断开""" server_name, real_tool_name = tool_name.split("___") - logger.info(f"正在服务器[{server_name}]上调用工具[{real_tool_name}]") - session = self.sessions[server_name] - try: - response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30) - except asyncio.TimeoutError: - logger.error(f"调用工具[{real_tool_name}]超时") - return f"调用工具[{real_tool_name}]超时" - logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}") - return response.content + logger.info(f"按需连接到服务器[{server_name}]调用工具[{real_tool_name}]") + + async with self._create_session_context(server_name) as session: + try: + response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30) + logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}") + return response.content + except asyncio.TimeoutError: + logger.error(f"调用工具[{real_tool_name}]超时") + return f"调用工具[{real_tool_name}]超时" def get_friendly_name(self, tool_name: str): logger.debug(tool_name) server_name, real_tool_name = tool_name.split("___") return (self.server_config[server_name].friendly_name or server_name) + " - " + real_tool_name + def clear_tools_cache(self): + """清除工具列表缓存""" + logger.info("清除工具列表缓存") + self._tools_cache = None + self._cache_initialized = False + async def cleanup(self): + """清理资源(不销毁单例)""" logger.debug("正在清理MCPClient资源") + # 只清除缓存,不销毁单例 + # self.clear_tools_cache() # 保留缓存,避免重复获取工具列表 await self.exit_stack.aclose() + # 重新初始化exit_stack以便后续使用 + self.exit_stack = AsyncExitStack() logger.debug("MCPClient资源清理完成") + + @classmethod + async def destroy_instance(cls): + """完全销毁单例实例(仅在应用关闭时使用)""" + if cls._instance is not None: + logger.info("销毁MCPClient单例") + await cls._instance.cleanup() + cls._instance.clear_tools_cache() + cls._instance = None + cls._initialized = False + logger.debug("MCPClient单例已销毁")