diff --git a/nonebot_plugin_llmchat/__init__.py b/nonebot_plugin_llmchat/__init__.py index 92a3060..cbca006 100755 --- a/nonebot_plugin_llmchat/__init__.py +++ b/nonebot_plugin_llmchat/__init__.py @@ -395,12 +395,14 @@ async def process_messages(context_id: int, is_group: bool = True): systemPrompt = "\n".join(system_lines) if preset.support_mcp: - systemPrompt += "你也可以使用一些工具,下面是关于这些工具的额外说明:\n" + systemPrompt += "\n你也可以使用一些工具,下面是关于这些工具的额外说明:\n" for mcp_name, mcp_config in plugin_config.mcp_servers.items(): - if mcp_config.addtional_prompt: - systemPrompt += f"{mcp_name}:{mcp_config.addtional_prompt}" + if mcp_config.additional_prompt: + systemPrompt += f"{mcp_name}:{mcp_config.additional_prompt}" systemPrompt += "\n" + logger.debug(f"构建系统提示词:\n{systemPrompt}") + messages: list[ChatCompletionMessageParam] = [ {"role": "system", "content": systemPrompt} ] @@ -478,8 +480,23 @@ async def process_messages(context_id: int, is_group: bool = True): new_messages.append(llm_reply) for tool_call in message.tool_calls: + logger.debug(f"处理工具调用:{tool_call.function.name} 参数:{tool_call.function.arguments}") + tool_name = tool_call.function.name - tool_args = json.loads(tool_call.function.arguments) + try: + tool_args = json.loads(tool_call.function.arguments) + except (json.JSONDecodeError, TypeError, ValueError) as e: + error_message = ( + f"工具调用参数格式错误,无法解析 {tool_name} 的 arguments: {e!s}. " + f"原始参数: {tool_call.function.arguments}" + ) + logger.warning(error_message) + new_messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": error_message, + }) + continue # 发送工具调用提示 await handler.send(Message(f"正在使用{mcp_client.get_friendly_name(tool_name)}")) diff --git a/nonebot_plugin_llmchat/config.py b/nonebot_plugin_llmchat/config.py index 8d94ec4..06bc55d 100755 --- a/nonebot_plugin_llmchat/config.py +++ b/nonebot_plugin_llmchat/config.py @@ -29,7 +29,7 @@ class MCPServerConfig(BaseModel): # 额外字段 friendly_name: str | None = Field(None, description="MCP服务器友好名称") - addtional_prompt: str | None = Field(None, description="额外提示词") + additional_prompt: str | None = Field(None, description="额外提示词") class ScopedConfig(BaseModel): """LLM Chat Plugin配置""" diff --git a/nonebot_plugin_llmchat/mcpclient.py b/nonebot_plugin_llmchat/mcpclient.py index 8861dd9..5dc6b6c 100644 --- a/nonebot_plugin_llmchat/mcpclient.py +++ b/nonebot_plugin_llmchat/mcpclient.py @@ -1,5 +1,7 @@ import asyncio from contextlib import AsyncExitStack +from time import monotonic +from typing import Any, cast from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client @@ -13,6 +15,8 @@ from .onebottools import OneBotTools class MCPClient: _instance = None _initialized = False + _SESSION_TTL_SECONDS = 600 + _SESSION_CLEANUP_INTERVAL_SECONDS = 60 def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None): if cls._instance is None: @@ -30,6 +34,10 @@ class MCPClient: self.server_config = server_config self.sessions = {} self.exit_stack = AsyncExitStack() + self._session_exit_stacks: dict[str, AsyncExitStack] = {} + self._session_last_used: dict[str, float] = {} + self._session_lock = asyncio.Lock() + self._session_cleanup_task: asyncio.Task | None = None # 添加工具列表缓存 self._tools_cache: list | None = None self._cache_initialized = False @@ -55,80 +63,115 @@ class MCPClient: return cls._instance async def connect_to_servers(self): + await self._ensure_cleanup_task() logger.info(f"开始连接{len(self.server_config)}个MCP服务器") - for server_name, config in self.server_config.items(): + for server_name in self.server_config: logger.debug(f"正在连接服务器[{server_name}]") - if config.url: - sse_transport = await self.exit_stack.enter_async_context(sse_client(url=config.url, headers=config.headers)) - read, write = sse_transport - self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write)) - await self.sessions[server_name].initialize() - elif config.command: - stdio_transport = await self.exit_stack.enter_async_context( - stdio_client(StdioServerParameters(**config.model_dump())) - ) - read, write = stdio_transport - self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write)) - await self.sessions[server_name].initialize() - else: - raise ValueError("Server config must have either url or command") - + await self._get_or_create_session(server_name) logger.info(f"已成功连接到MCP服务器[{server_name}]") - def _create_session_context(self, server_name: str): - """创建临时会话的异步上下文管理器""" + async def _create_server_session(self, server_name: str) -> tuple[ClientSession, AsyncExitStack]: + """创建并初始化一个新的服务器会话。""" config = self.server_config[server_name] + session_stack = AsyncExitStack() + if config.url: + transport = await session_stack.enter_async_context( + sse_client(url=config.url, headers=config.headers) + ) + elif config.command: + transport = await session_stack.enter_async_context( + cast(Any, stdio_client(StdioServerParameters(**config.model_dump()))) + ) + else: + raise ValueError("Server config must have either url or command") - class SessionContext: - def __init__(self): - self.session = None - self.exit_stack = AsyncExitStack() + read, write = transport + session = await session_stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + return session, session_stack - async def __aenter__(self): - if config.url: - transport = await self.exit_stack.enter_async_context( - sse_client(url=config.url, headers=config.headers) - ) - elif config.command: - transport = await self.exit_stack.enter_async_context( - stdio_client(StdioServerParameters(**config.model_dump())) - ) - else: - raise ValueError("Server config must have either url or command") + async def _close_server_session(self, server_name: str): + """关闭指定服务器会话。""" + session_stack = self._session_exit_stacks.pop(server_name, None) + self.sessions.pop(server_name, None) + self._session_last_used.pop(server_name, None) - read, write = transport - self.session = await self.exit_stack.enter_async_context(ClientSession(read, write)) - await self.session.initialize() - return self.session + if session_stack is not None: + await session_stack.aclose() - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.exit_stack.aclose() + async def _get_or_create_session(self, server_name: str) -> ClientSession: + """获取可复用会话;若不存在或已过期则新建。""" + now = monotonic() + async with self._session_lock: + last_used = self._session_last_used.get(server_name) + session = self.sessions.get(server_name) - return SessionContext() + # 空闲超过阈值则销毁重建 + if session is not None and last_used is not None: + if now - last_used > self._SESSION_TTL_SECONDS: + logger.info(f"服务器[{server_name}]会话空闲超过10分钟,重新创建") + await self._close_server_session(server_name) + session = None + + if session is None: + session, session_stack = await self._create_server_session(server_name) + self.sessions[server_name] = session + self._session_exit_stacks[server_name] = session_stack + + self._session_last_used[server_name] = now + return self.sessions[server_name] + + async def _cleanup_expired_sessions(self): + """回收空闲过期会话。""" + now = monotonic() + async with self._session_lock: + expired_servers = [ + server_name + for server_name, last_used in self._session_last_used.items() + if now - last_used > self._SESSION_TTL_SECONDS + ] + + for server_name in expired_servers: + logger.info(f"回收空闲MCP会话[{server_name}]") + await self._close_server_session(server_name) + + async def _session_cleanup_loop(self): + try: + while True: + await asyncio.sleep(self._SESSION_CLEANUP_INTERVAL_SECONDS) + await self._cleanup_expired_sessions() + except asyncio.CancelledError: + logger.debug("MCP会话清理任务已取消") + raise + + async def _ensure_cleanup_task(self): + if self._session_cleanup_task is None or self._session_cleanup_task.done(): + self._session_cleanup_task = asyncio.create_task(self._session_cleanup_loop()) async def init_tools_cache(self): """初始化工具列表缓存""" if not self._cache_initialized: + await self._ensure_cleanup_task() available_tools = [] logger.info(f"初始化工具列表缓存,需要连接{len(self.server_config)}个服务器") for server_name in self.server_config.keys(): logger.debug(f"正在从服务器[{server_name}]获取工具列表") - async with self._create_session_context(server_name) as session: - response = await session.list_tools() - tools = response.tools - logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具") + session = await self._get_or_create_session(server_name) + response = await session.list_tools() + tools = response.tools + logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具") - available_tools.extend( - { - "type": "function", - "function": { - "name": f"mcp__{server_name}__{tool.name}", - "description": tool.description, - "parameters": tool.inputSchema, - }, - } - for tool in tools - ) + available_tools.extend( + { + "type": "function", + "function": { + "name": f"mcp__{server_name}__{tool.name}", + "description": tool.description, + "parameters": tool.inputSchema, + }, + } + for tool in tools + ) # 缓存工具列表 self._tools_cache = available_tools @@ -149,7 +192,7 @@ class MCPClient: return available_tools async def call_tool(self, tool_name: str, tool_args: dict, group_id: int | None = None, bot_id: str | None = None): - """按需连接调用工具,调用后立即断开""" + """按需调用工具,MCP会话会在10分钟空闲后自动回收。""" # 检查是否是OneBot内置工具 if tool_name.startswith("ob__"): if group_id is None or bot_id is None: @@ -168,14 +211,20 @@ class MCPClient: real_tool_name = parts[2] logger.info(f"按需连接到服务器[{server_name}]调用工具[{real_tool_name}]") - async with self._create_session_context(server_name) as session: - try: - response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30) - logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}") - return response.content - except asyncio.TimeoutError: - logger.error(f"调用工具[{real_tool_name}]超时") - return f"调用工具[{real_tool_name}]超时" + try: + await self._ensure_cleanup_task() + session = await self._get_or_create_session(server_name) + response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30) + logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}") + return response.content + except asyncio.TimeoutError: + logger.error(f"调用工具[{real_tool_name}]超时") + return f"调用工具[{real_tool_name}]超时" + except (RuntimeError, ValueError, TypeError, OSError, ConnectionError) as e: + logger.opt(exception=e).error(f"调用工具[{real_tool_name}]失败,准备重置会话") + async with self._session_lock: + await self._close_server_session(server_name) + return f"调用工具[{real_tool_name}]失败: {e!s}" # 未知工具类型 return f"未知的工具类型: {tool_name}" @@ -211,6 +260,19 @@ class MCPClient: logger.debug("正在清理MCPClient资源") # 只清除缓存,不销毁单例 # self.clear_tools_cache() # 保留缓存,避免重复获取工具列表 + + if self._session_cleanup_task is not None: + self._session_cleanup_task.cancel() + try: + await self._session_cleanup_task + except asyncio.CancelledError: + pass + self._session_cleanup_task = None + + async with self._session_lock: + for server_name in list(self.sessions.keys()): + await self._close_server_session(server_name) + await self.exit_stack.aclose() # 重新初始化exit_stack以便后续使用 self.exit_stack = AsyncExitStack()