mirror of
https://github.com/FuQuan233/nonebot-plugin-llmchat.git
synced 2026-06-28 00:02:04 +00:00
290 lines
12 KiB
Python
290 lines
12 KiB
Python
import asyncio
|
||
from contextlib import AsyncExitStack
|
||
from time import monotonic
|
||
from typing import Any, cast
|
||
|
||
from mcp import ClientSession, StdioServerParameters
|
||
from mcp.client.sse import sse_client
|
||
from mcp.client.stdio import stdio_client
|
||
from nonebot import logger
|
||
|
||
from .config import MCPServerConfig
|
||
from .onebottools import OneBotTools
|
||
|
||
|
||
class MCPClient:
    """Process-wide singleton that owns MCP server sessions and dispatches tool calls.

    One session per configured server is created on demand, reused while it is
    active, and closed again once it has been idle for ``_SESSION_TTL_SECONDS``.
    A background task sweeps idle sessions every
    ``_SESSION_CLEANUP_INTERVAL_SECONDS``.  Tool names are namespaced:
    ``mcp__<server>__<tool>`` is routed to an MCP server and ``ob__...`` to the
    built-in OneBot tools.
    """

    _instance = None
    _initialized = False
    # Idle time (seconds) after which a cached session is closed and rebuilt.
    _SESSION_TTL_SECONDS = 600
    # Pause (seconds) between two runs of the background idle-session sweep.
    _SESSION_CLEANUP_INTERVAL_SECONDS = 60

    def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None):
        """Return the shared instance, allocating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, server_config: dict[str, MCPServerConfig] | None = None):
        """Initialise the singleton once; later constructions are no-ops.

        Args:
            server_config: Mapping of server name to its configuration.
                Required for the first (initialising) construction.

        Raises:
            ValueError: If the singleton is not initialised yet and
                ``server_config`` is ``None``.
        """
        if self._initialized:
            return

        if server_config is None:
            # __new__ has already published this object as the singleton.
            # Withdraw it so that instance()/get_instance() never hand out an
            # object that has none of its attributes set.
            owner = type(self)
            if owner._instance is self:
                owner._instance = None
            raise ValueError("server_config must be provided for first initialization")

        logger.info(f"正在初始化MCPClient单例,共有{len(server_config)}个服务器配置")
        self.server_config = server_config
        self.sessions = {}
        self.exit_stack = AsyncExitStack()
        self._session_exit_stacks: dict[str, AsyncExitStack] = {}
        self._session_last_used: dict[str, float] = {}
        self._session_lock = asyncio.Lock()
        self._session_cleanup_task: asyncio.Task | None = None
        # Cached tool list, filled lazily by init_tools_cache().
        self._tools_cache: list | None = None
        self._cache_initialized = False
        # Built-in OneBot tools.
        self.onebot_tools = OneBotTools()
        self._initialized = True
        logger.debug("MCPClient单例初始化成功")

    @classmethod
    def get_instance(cls, server_config: dict[str, MCPServerConfig] | None = None):
        """Return the singleton, creating it from ``server_config`` if needed.

        Raises:
            ValueError: If no initialised instance exists and ``server_config``
                is ``None``.
        """
        current = cls._instance
        if current is None or not current._initialized:
            if server_config is None:
                raise ValueError("server_config must be provided for first initialization")
            # cls(...) stores the object in cls._instance via __new__ and
            # (re)runs __init__ on an instance that never finished initialising.
            current = cls(server_config)
            cls._instance = current
        return current

    @classmethod
    def instance(cls):
        """Return the already-initialised singleton.

        Raises:
            RuntimeError: If the singleton has not been initialised yet.
        """
        current = cls._instance
        if current is None or not current._initialized:
            raise RuntimeError("MCPClient has not been initialized. Call get_instance() first.")
        return current

    @staticmethod
    def _split_mcp_tool_name(tool_name: str) -> tuple[str, str] | None:
        """Split ``mcp__<server>__<tool>`` into ``(server, tool)``.

        Only the first two separators are significant, so a tool name that
        itself contains ``__`` is kept intact.  Returns ``None`` when the name
        does not have the ``mcp`` prefix followed by two further fields.
        """
        pieces = tool_name.split("__", 2)
        if len(pieces) != 3 or pieces[0] != "mcp":
            return None
        return pieces[1], pieces[2]

    async def connect_to_servers(self):
        """Open (or reuse) a session for every configured server."""
        await self._ensure_cleanup_task()
        logger.info(f"开始连接{len(self.server_config)}个MCP服务器")
        for server_name in self.server_config:
            logger.debug(f"正在连接服务器[{server_name}]")
            await self._get_or_create_session(server_name)
            logger.info(f"已成功连接到MCP服务器[{server_name}]")

    async def _create_server_session(self, server_name: str) -> tuple[ClientSession, AsyncExitStack]:
        """Create and initialise a fresh session for ``server_name``.

        Returns:
            The initialised session together with the exit stack that owns the
            transport and the session; closing the stack releases both.

        Raises:
            KeyError: If ``server_name`` is not in ``self.server_config``.
            ValueError: If the config has neither ``url`` nor ``command``.
        """
        config = self.server_config[server_name]
        if config.url:
            transport_cm = sse_client(url=config.url, headers=config.headers)
        elif config.command:
            transport_cm = cast(Any, stdio_client(StdioServerParameters(**config.model_dump())))
        else:
            raise ValueError("Server config must have either url or command")

        session_stack = AsyncExitStack()
        try:
            read, write = await session_stack.enter_async_context(transport_cm)
            session = await session_stack.enter_async_context(ClientSession(read, write))
            await session.initialize()
        except BaseException:
            # The caller never receives the stack on failure, so release
            # whatever was already entered here instead of leaking it.
            await session_stack.aclose()
            raise
        return session, session_stack

    async def _close_server_session(self, server_name: str):
        """Forget the session of ``server_name`` and close its exit stack.

        The bookkeeping entries are removed before the stack is closed, so the
        session is no longer handed out even if closing raises.
        """
        session_stack = self._session_exit_stacks.pop(server_name, None)
        self.sessions.pop(server_name, None)
        self._session_last_used.pop(server_name, None)

        if session_stack is not None:
            await session_stack.aclose()

    async def _get_or_create_session(self, server_name: str) -> ClientSession:
        """Return a reusable session, creating one if missing or expired."""
        async with self._session_lock:
            session = self.sessions.get(server_name)
            last_used = self._session_last_used.get(server_name)

            # Read the clock only after the lock is held so that time spent
            # waiting for the lock counts towards the idle period.
            if session is not None and last_used is not None:
                if monotonic() - last_used > self._SESSION_TTL_SECONDS:
                    logger.info(f"服务器[{server_name}]会话空闲超过10分钟,重新创建")
                    await self._close_server_session(server_name)
                    session = None

            if session is None:
                session, session_stack = await self._create_server_session(server_name)
                self.sessions[server_name] = session
                self._session_exit_stacks[server_name] = session_stack

            # Stamp with a fresh reading: connecting may have taken a while and
            # the session must not start its life already partly "idle".
            self._session_last_used[server_name] = monotonic()
            return session

    async def _cleanup_expired_sessions(self):
        """Close every session that has been idle longer than the TTL."""
        now = monotonic()
        async with self._session_lock:
            # Collect first: _close_server_session mutates the dict.
            expired_servers = [
                server_name
                for server_name, last_used in self._session_last_used.items()
                if now - last_used > self._SESSION_TTL_SECONDS
            ]

            for server_name in expired_servers:
                logger.info(f"回收空闲MCP会话[{server_name}]")
                await self._close_server_session(server_name)

    async def _session_cleanup_loop(self):
        """Periodically sweep idle sessions until the task is cancelled."""
        try:
            while True:
                await asyncio.sleep(self._SESSION_CLEANUP_INTERVAL_SECONDS)
                try:
                    await self._cleanup_expired_sessions()
                except Exception as e:
                    # One failing close must not end the periodic sweep.
                    # CancelledError is a BaseException and still propagates.
                    logger.opt(exception=e).error("回收空闲MCP会话失败")
        except asyncio.CancelledError:
            logger.debug("MCP会话清理任务已取消")
            raise

    async def _ensure_cleanup_task(self):
        """Start the background sweep task if it is not currently running."""
        task = self._session_cleanup_task
        if task is None or task.done():
            self._session_cleanup_task = asyncio.create_task(self._session_cleanup_loop())

    async def init_tools_cache(self):
        """Fetch the tool list of every server once and cache it.

        Each tool is stored in function-call format under the namespaced name
        ``mcp__<server>__<tool>``.  Does nothing if the cache is already built.
        """
        if self._cache_initialized:
            return

        await self._ensure_cleanup_task()
        available_tools = []
        logger.info(f"初始化工具列表缓存,需要连接{len(self.server_config)}个服务器")
        for server_name in self.server_config:
            logger.debug(f"正在从服务器[{server_name}]获取工具列表")
            session = await self._get_or_create_session(server_name)
            response = await session.list_tools()
            tools = response.tools
            logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具")

            available_tools.extend(
                {
                    "type": "function",
                    "function": {
                        "name": f"mcp__{server_name}__{tool.name}",
                        "description": tool.description,
                        "parameters": tool.inputSchema,
                    },
                }
                for tool in tools
            )

        # Publish the cache only after every server answered.
        self._tools_cache = available_tools
        self._cache_initialized = True

        logger.info(f"工具列表缓存完成,共缓存{len(available_tools)}个工具")

    async def get_available_tools(self, is_group: bool):
        """Return the cached MCP tools, plus the OneBot tools for group chats.

        The returned list is a copy; mutating it does not affect the cache.
        """
        await self.init_tools_cache()
        available_tools = list(self._tools_cache) if self._tools_cache else []
        if is_group:
            # Group chat: offer the OneBot tools in addition to the MCP tools.
            available_tools.extend(self.onebot_tools.get_available_tools())
        logger.debug(f"获取可用工具列表,共{len(available_tools)}个工具")
        return available_tools

    async def call_tool(self, tool_name: str, tool_args: dict, group_id: int | None = None, bot_id: str | None = None):
        """Call a tool by its namespaced name.

        Args:
            tool_name: ``ob__...`` for a OneBot tool or
                ``mcp__<server>__<tool>`` for an MCP tool.
            tool_args: Arguments forwarded to the tool.
            group_id: Required for OneBot tools.
            bot_id: Required for OneBot tools.

        Returns:
            The tool result (``response.content`` for MCP tools), or a
            human-readable error string when the name is malformed, the server
            or tool type is unknown, the call times out, or the call fails.
            After a failed MCP call the server's session is closed so the next
            call reconnects.
        """
        # Built-in OneBot tools.
        if tool_name.startswith("ob__"):
            if group_id is None or bot_id is None:
                return "QQ工具需要提供group_id和bot_id参数"
            logger.info(f"调用OneBot工具[{tool_name}]")
            return await self.onebot_tools.call_tool(tool_name, tool_args, group_id, bot_id)

        # MCP tools: mcp__server_name__tool_name
        if tool_name.startswith("mcp__"):
            parsed = self._split_mcp_tool_name(tool_name)
            if parsed is None:
                return f"MCP工具名称格式错误: {tool_name}"

            server_name, real_tool_name = parsed
            if server_name not in self.server_config:
                # Tool names come from the model; an unconfigured server must
                # yield an error result rather than an uncaught KeyError.
                return f"未知的MCP服务器: {server_name}"

            logger.info(f"按需连接到服务器[{server_name}]调用工具[{real_tool_name}]")

            try:
                await self._ensure_cleanup_task()
                session = await self._get_or_create_session(server_name)
                response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30)
                logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}")
                return response.content
            except asyncio.TimeoutError:
                logger.error(f"调用工具[{real_tool_name}]超时")
                return f"调用工具[{real_tool_name}]超时"
            except (RuntimeError, ValueError, TypeError, OSError, ConnectionError) as e:
                logger.opt(exception=e).error(f"调用工具[{real_tool_name}]失败,准备重置会话")
                async with self._session_lock:
                    await self._close_server_session(server_name)
                return f"调用工具[{real_tool_name}]失败: {e!s}"

        # Neither prefix matched.
        return f"未知的工具类型: {tool_name}"

    def get_friendly_name(self, tool_name: str):
        """Return a display name for ``tool_name``.

        OneBot tools are delegated to ``self.onebot_tools``.  MCP tools become
        ``"<friendly server name> - <tool>"``, using the raw server name when
        the server has no friendly name or is not configured.  Malformed or
        unrecognised names are returned unchanged.
        """
        logger.debug(tool_name)
        # Built-in OneBot tools.
        if tool_name.startswith("ob__"):
            return self.onebot_tools.get_friendly_name(tool_name)

        # MCP tools: mcp__server_name__tool_name
        if tool_name.startswith("mcp__"):
            parsed = self._split_mcp_tool_name(tool_name)
            if parsed is None:
                return tool_name  # malformed: keep the original name

            server_name, real_tool_name = parsed
            config = self.server_config.get(server_name)
            display_server = (config.friendly_name if config is not None else None) or server_name
            return display_server + " - " + real_tool_name

        # Unknown tool type: keep the original name.
        return tool_name

    def clear_tools_cache(self):
        """Drop the cached tool list so the next lookup rebuilds it."""
        logger.info("清除工具列表缓存")
        self._tools_cache = None
        self._cache_initialized = False

    async def cleanup(self):
        """Release sessions and the sweep task without destroying the singleton.

        The tool cache is deliberately kept so the tool list does not have to
        be fetched again afterwards.
        """
        logger.debug("正在清理MCPClient资源")

        if self._session_cleanup_task is not None:
            self._session_cleanup_task.cancel()
            try:
                await self._session_cleanup_task
            except asyncio.CancelledError:
                pass
            self._session_cleanup_task = None

        async with self._session_lock:
            # Snapshot the keys: closing a session removes it from the dict.
            for server_name in list(self.sessions):
                await self._close_server_session(server_name)

        await self.exit_stack.aclose()
        # Start over with a fresh stack so the client stays usable.
        self.exit_stack = AsyncExitStack()
        logger.debug("MCPClient资源清理完成")

    @classmethod
    async def destroy_instance(cls):
        """Tear down and forget the singleton (intended for application shutdown).

        The singleton reference is cleared even if releasing its resources
        raises; the exception still propagates to the caller.
        """
        current = cls._instance
        if current is None:
            return

        logger.info("销毁MCPClient单例")
        try:
            # An instance whose __init__ never completed owns no resources.
            if current._initialized:
                await current.cleanup()
                current.clear_tools_cache()
        finally:
            cls._instance = None
            cls._initialized = False
        logger.debug("MCPClient单例已销毁")
|