mirror of
https://github.com/FuQuan233/nonebot-plugin-llmchat.git
synced 2026-06-28 00:02:04 +00:00
🐛 优化MCPClient会话管理,添加会话过期清理机制和异常处理
This commit is contained in:
parent
802265ca22
commit
38af060cb2
1 changed files with 126 additions and 64 deletions
|
|
@ -1,5 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
|
from time import monotonic
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
from mcp import ClientSession, StdioServerParameters
|
from mcp import ClientSession, StdioServerParameters
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
|
|
@ -13,6 +15,8 @@ from .onebottools import OneBotTools
|
||||||
class MCPClient:
|
class MCPClient:
|
||||||
_instance = None
|
_instance = None
|
||||||
_initialized = False
|
_initialized = False
|
||||||
|
_SESSION_TTL_SECONDS = 600
|
||||||
|
_SESSION_CLEANUP_INTERVAL_SECONDS = 60
|
||||||
|
|
||||||
def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None):
|
def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
|
|
@ -30,6 +34,10 @@ class MCPClient:
|
||||||
self.server_config = server_config
|
self.server_config = server_config
|
||||||
self.sessions = {}
|
self.sessions = {}
|
||||||
self.exit_stack = AsyncExitStack()
|
self.exit_stack = AsyncExitStack()
|
||||||
|
self._session_exit_stacks: dict[str, AsyncExitStack] = {}
|
||||||
|
self._session_last_used: dict[str, float] = {}
|
||||||
|
self._session_lock = asyncio.Lock()
|
||||||
|
self._session_cleanup_task: asyncio.Task | None = None
|
||||||
# 添加工具列表缓存
|
# 添加工具列表缓存
|
||||||
self._tools_cache: list | None = None
|
self._tools_cache: list | None = None
|
||||||
self._cache_initialized = False
|
self._cache_initialized = False
|
||||||
|
|
@ -55,80 +63,115 @@ class MCPClient:
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
async def connect_to_servers(self):
|
async def connect_to_servers(self):
|
||||||
|
await self._ensure_cleanup_task()
|
||||||
logger.info(f"开始连接{len(self.server_config)}个MCP服务器")
|
logger.info(f"开始连接{len(self.server_config)}个MCP服务器")
|
||||||
for server_name, config in self.server_config.items():
|
for server_name in self.server_config:
|
||||||
logger.debug(f"正在连接服务器[{server_name}]")
|
logger.debug(f"正在连接服务器[{server_name}]")
|
||||||
if config.url:
|
await self._get_or_create_session(server_name)
|
||||||
sse_transport = await self.exit_stack.enter_async_context(sse_client(url=config.url, headers=config.headers))
|
|
||||||
read, write = sse_transport
|
|
||||||
self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
|
||||||
await self.sessions[server_name].initialize()
|
|
||||||
elif config.command:
|
|
||||||
stdio_transport = await self.exit_stack.enter_async_context(
|
|
||||||
stdio_client(StdioServerParameters(**config.model_dump()))
|
|
||||||
)
|
|
||||||
read, write = stdio_transport
|
|
||||||
self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
|
||||||
await self.sessions[server_name].initialize()
|
|
||||||
else:
|
|
||||||
raise ValueError("Server config must have either url or command")
|
|
||||||
|
|
||||||
logger.info(f"已成功连接到MCP服务器[{server_name}]")
|
logger.info(f"已成功连接到MCP服务器[{server_name}]")
|
||||||
|
|
||||||
def _create_session_context(self, server_name: str):
|
async def _create_server_session(self, server_name: str) -> tuple[ClientSession, AsyncExitStack]:
|
||||||
"""创建临时会话的异步上下文管理器"""
|
"""创建并初始化一个新的服务器会话。"""
|
||||||
config = self.server_config[server_name]
|
config = self.server_config[server_name]
|
||||||
|
session_stack = AsyncExitStack()
|
||||||
|
if config.url:
|
||||||
|
transport = await session_stack.enter_async_context(
|
||||||
|
sse_client(url=config.url, headers=config.headers)
|
||||||
|
)
|
||||||
|
elif config.command:
|
||||||
|
transport = await session_stack.enter_async_context(
|
||||||
|
cast(Any, stdio_client(StdioServerParameters(**config.model_dump())))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Server config must have either url or command")
|
||||||
|
|
||||||
class SessionContext:
|
read, write = transport
|
||||||
def __init__(self):
|
session = await session_stack.enter_async_context(ClientSession(read, write))
|
||||||
self.session = None
|
await session.initialize()
|
||||||
self.exit_stack = AsyncExitStack()
|
return session, session_stack
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def _close_server_session(self, server_name: str):
|
||||||
if config.url:
|
"""关闭指定服务器会话。"""
|
||||||
transport = await self.exit_stack.enter_async_context(
|
session_stack = self._session_exit_stacks.pop(server_name, None)
|
||||||
sse_client(url=config.url, headers=config.headers)
|
self.sessions.pop(server_name, None)
|
||||||
)
|
self._session_last_used.pop(server_name, None)
|
||||||
elif config.command:
|
|
||||||
transport = await self.exit_stack.enter_async_context(
|
|
||||||
stdio_client(StdioServerParameters(**config.model_dump()))
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("Server config must have either url or command")
|
|
||||||
|
|
||||||
read, write = transport
|
if session_stack is not None:
|
||||||
self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
await session_stack.aclose()
|
||||||
await self.session.initialize()
|
|
||||||
return self.session
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def _get_or_create_session(self, server_name: str) -> ClientSession:
|
||||||
await self.exit_stack.aclose()
|
"""获取可复用会话;若不存在或已过期则新建。"""
|
||||||
|
now = monotonic()
|
||||||
|
async with self._session_lock:
|
||||||
|
last_used = self._session_last_used.get(server_name)
|
||||||
|
session = self.sessions.get(server_name)
|
||||||
|
|
||||||
return SessionContext()
|
# 空闲超过阈值则销毁重建
|
||||||
|
if session is not None and last_used is not None:
|
||||||
|
if now - last_used > self._SESSION_TTL_SECONDS:
|
||||||
|
logger.info(f"服务器[{server_name}]会话空闲超过10分钟,重新创建")
|
||||||
|
await self._close_server_session(server_name)
|
||||||
|
session = None
|
||||||
|
|
||||||
|
if session is None:
|
||||||
|
session, session_stack = await self._create_server_session(server_name)
|
||||||
|
self.sessions[server_name] = session
|
||||||
|
self._session_exit_stacks[server_name] = session_stack
|
||||||
|
|
||||||
|
self._session_last_used[server_name] = now
|
||||||
|
return self.sessions[server_name]
|
||||||
|
|
||||||
|
async def _cleanup_expired_sessions(self):
|
||||||
|
"""回收空闲过期会话。"""
|
||||||
|
now = monotonic()
|
||||||
|
async with self._session_lock:
|
||||||
|
expired_servers = [
|
||||||
|
server_name
|
||||||
|
for server_name, last_used in self._session_last_used.items()
|
||||||
|
if now - last_used > self._SESSION_TTL_SECONDS
|
||||||
|
]
|
||||||
|
|
||||||
|
for server_name in expired_servers:
|
||||||
|
logger.info(f"回收空闲MCP会话[{server_name}]")
|
||||||
|
await self._close_server_session(server_name)
|
||||||
|
|
||||||
|
async def _session_cleanup_loop(self):
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(self._SESSION_CLEANUP_INTERVAL_SECONDS)
|
||||||
|
await self._cleanup_expired_sessions()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug("MCP会话清理任务已取消")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _ensure_cleanup_task(self):
|
||||||
|
if self._session_cleanup_task is None or self._session_cleanup_task.done():
|
||||||
|
self._session_cleanup_task = asyncio.create_task(self._session_cleanup_loop())
|
||||||
|
|
||||||
async def init_tools_cache(self):
|
async def init_tools_cache(self):
|
||||||
"""初始化工具列表缓存"""
|
"""初始化工具列表缓存"""
|
||||||
if not self._cache_initialized:
|
if not self._cache_initialized:
|
||||||
|
await self._ensure_cleanup_task()
|
||||||
available_tools = []
|
available_tools = []
|
||||||
logger.info(f"初始化工具列表缓存,需要连接{len(self.server_config)}个服务器")
|
logger.info(f"初始化工具列表缓存,需要连接{len(self.server_config)}个服务器")
|
||||||
for server_name in self.server_config.keys():
|
for server_name in self.server_config.keys():
|
||||||
logger.debug(f"正在从服务器[{server_name}]获取工具列表")
|
logger.debug(f"正在从服务器[{server_name}]获取工具列表")
|
||||||
async with self._create_session_context(server_name) as session:
|
session = await self._get_or_create_session(server_name)
|
||||||
response = await session.list_tools()
|
response = await session.list_tools()
|
||||||
tools = response.tools
|
tools = response.tools
|
||||||
logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具")
|
logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具")
|
||||||
|
|
||||||
available_tools.extend(
|
available_tools.extend(
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": f"mcp__{server_name}__{tool.name}",
|
"name": f"mcp__{server_name}__{tool.name}",
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"parameters": tool.inputSchema,
|
"parameters": tool.inputSchema,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for tool in tools
|
for tool in tools
|
||||||
)
|
)
|
||||||
|
|
||||||
# 缓存工具列表
|
# 缓存工具列表
|
||||||
self._tools_cache = available_tools
|
self._tools_cache = available_tools
|
||||||
|
|
@ -149,7 +192,7 @@ class MCPClient:
|
||||||
return available_tools
|
return available_tools
|
||||||
|
|
||||||
async def call_tool(self, tool_name: str, tool_args: dict, group_id: int | None = None, bot_id: str | None = None):
|
async def call_tool(self, tool_name: str, tool_args: dict, group_id: int | None = None, bot_id: str | None = None):
|
||||||
"""按需连接调用工具,调用后立即断开"""
|
"""按需调用工具,MCP会话会在10分钟空闲后自动回收。"""
|
||||||
# 检查是否是OneBot内置工具
|
# 检查是否是OneBot内置工具
|
||||||
if tool_name.startswith("ob__"):
|
if tool_name.startswith("ob__"):
|
||||||
if group_id is None or bot_id is None:
|
if group_id is None or bot_id is None:
|
||||||
|
|
@ -168,14 +211,20 @@ class MCPClient:
|
||||||
real_tool_name = parts[2]
|
real_tool_name = parts[2]
|
||||||
logger.info(f"按需连接到服务器[{server_name}]调用工具[{real_tool_name}]")
|
logger.info(f"按需连接到服务器[{server_name}]调用工具[{real_tool_name}]")
|
||||||
|
|
||||||
async with self._create_session_context(server_name) as session:
|
try:
|
||||||
try:
|
await self._ensure_cleanup_task()
|
||||||
response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30)
|
session = await self._get_or_create_session(server_name)
|
||||||
logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}")
|
response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30)
|
||||||
return response.content
|
logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}")
|
||||||
except asyncio.TimeoutError:
|
return response.content
|
||||||
logger.error(f"调用工具[{real_tool_name}]超时")
|
except asyncio.TimeoutError:
|
||||||
return f"调用工具[{real_tool_name}]超时"
|
logger.error(f"调用工具[{real_tool_name}]超时")
|
||||||
|
return f"调用工具[{real_tool_name}]超时"
|
||||||
|
except (RuntimeError, ValueError, TypeError, OSError, ConnectionError) as e:
|
||||||
|
logger.opt(exception=e).error(f"调用工具[{real_tool_name}]失败,准备重置会话")
|
||||||
|
async with self._session_lock:
|
||||||
|
await self._close_server_session(server_name)
|
||||||
|
return f"调用工具[{real_tool_name}]失败: {e!s}"
|
||||||
|
|
||||||
# 未知工具类型
|
# 未知工具类型
|
||||||
return f"未知的工具类型: {tool_name}"
|
return f"未知的工具类型: {tool_name}"
|
||||||
|
|
@ -211,6 +260,19 @@ class MCPClient:
|
||||||
logger.debug("正在清理MCPClient资源")
|
logger.debug("正在清理MCPClient资源")
|
||||||
# 只清除缓存,不销毁单例
|
# 只清除缓存,不销毁单例
|
||||||
# self.clear_tools_cache() # 保留缓存,避免重复获取工具列表
|
# self.clear_tools_cache() # 保留缓存,避免重复获取工具列表
|
||||||
|
|
||||||
|
if self._session_cleanup_task is not None:
|
||||||
|
self._session_cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._session_cleanup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._session_cleanup_task = None
|
||||||
|
|
||||||
|
async with self._session_lock:
|
||||||
|
for server_name in list(self.sessions.keys()):
|
||||||
|
await self._close_server_session(server_name)
|
||||||
|
|
||||||
await self.exit_stack.aclose()
|
await self.exit_stack.aclose()
|
||||||
# 重新初始化exit_stack以便后续使用
|
# 重新初始化exit_stack以便后续使用
|
||||||
self.exit_stack = AsyncExitStack()
|
self.exit_stack = AsyncExitStack()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue