mirror of
https://github.com/FuQuan233/nonebot-plugin-llmchat.git
synced 2026-03-26 22:52:30 +00:00
317 lines
14 KiB
Python
317 lines
14 KiB
Python
import asyncio
|
||
from contextlib import AsyncExitStack
|
||
|
||
from mcp import ClientSession, StdioServerParameters
|
||
from mcp.client.sse import sse_client
|
||
from mcp.client.stdio import stdio_client
|
||
# Import the streamable-HTTP client factory under a single local name.
# The first spelling is tried and the second is the fallback; presumably the
# name differs between MCP SDK releases -- TODO confirm against the pinned version.
# NOTE(review): the two factories may not accept the same keyword arguments
# (this module calls it with ``url=`` and ``headers=``) -- verify before upgrading.
try:
    from mcp.client.streamable_http import streamable_http_client as streamablehttp_client
except ImportError:
    # Only a missing name/module should trigger the fallback; a bare ``except``
    # would also swallow KeyboardInterrupt/SystemExit and unrelated import-time bugs.
    from mcp.client.streamable_http import streamablehttp_client
|
||
from nonebot import logger
|
||
|
||
from .config import MCPServerConfig, PresetConfig, ScopedConfig, transportType
|
||
from .onebottools import OneBotTools
|
||
from .scheduler import SchedulerManager
|
||
from .submodel_caller import SubModelCaller
|
||
|
||
|
||
class MCPClient:
|
||
_instance = None
|
||
_initialized = False
|
||
|
||
def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None, plugin_config: ScopedConfig | None = None):
|
||
if cls._instance is None:
|
||
cls._instance = super().__new__(cls)
|
||
return cls._instance
|
||
|
||
def __init__(self, server_config: dict[str, MCPServerConfig] | None = None, plugin_config: ScopedConfig | None = None):
|
||
if self._initialized:
|
||
return
|
||
|
||
if server_config is None:
|
||
raise ValueError("server_config must be provided for first initialization")
|
||
|
||
logger.info(f"正在初始化MCPClient单例,共有{len(server_config)}个服务器配置")
|
||
self.server_config = server_config
|
||
self.plugin_config = plugin_config
|
||
self.sessions = {}
|
||
self.exit_stack = AsyncExitStack()
|
||
# 添加工具列表缓存
|
||
self._tools_cache: list | None = None
|
||
self._cache_initialized = False
|
||
# 初始化OneBot工具
|
||
self.onebot_tools = OneBotTools()
|
||
# 初始化定时任务管理器
|
||
self.scheduler_manager = SchedulerManager.get_instance()
|
||
# 初始化子模型调用器(如果有 plugin_config)
|
||
self.submodel_caller = SubModelCaller.get_instance(plugin_config) if plugin_config else None
|
||
self._initialized = True
|
||
logger.debug("MCPClient单例初始化成功")
|
||
|
||
@classmethod
|
||
def get_instance(cls, server_config: dict[str, MCPServerConfig] | None = None, plugin_config: ScopedConfig | None = None):
|
||
"""获取MCPClient实例"""
|
||
if cls._instance is None:
|
||
if server_config is None:
|
||
raise ValueError("server_config must be provided for first initialization")
|
||
cls._instance = cls(server_config, plugin_config)
|
||
return cls._instance
|
||
|
||
@classmethod
|
||
def instance(cls):
|
||
"""快速获取已初始化的MCPClient实例,如果未初始化则抛出异常"""
|
||
if cls._instance is None:
|
||
raise RuntimeError("MCPClient has not been initialized. Call get_instance() first.")
|
||
return cls._instance
|
||
|
||
async def connect_to_servers(self):
|
||
logger.info(f"开始连接{len(self.server_config)}个MCP服务器")
|
||
for server_name, config in self.server_config.items():
|
||
logger.debug(f"正在连接服务器[{server_name}]")
|
||
if config.url:
|
||
sse_transport = await self.exit_stack.enter_async_context(sse_client(url=config.url, headers=config.headers))
|
||
read, write = sse_transport
|
||
self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
||
await self.sessions[server_name].initialize()
|
||
elif config.command:
|
||
stdio_transport = await self.exit_stack.enter_async_context(
|
||
stdio_client(StdioServerParameters(**config.model_dump()))
|
||
)
|
||
read, write = stdio_transport
|
||
self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
||
await self.sessions[server_name].initialize()
|
||
else:
|
||
raise ValueError("Server config must have either url or command")
|
||
|
||
logger.info(f"已成功连接到MCP服务器[{server_name}]")
|
||
|
||
def _create_session_context(self, server_name: str):
|
||
"""创建临时会话的异步上下文管理器"""
|
||
config = self.server_config[server_name]
|
||
|
||
class SessionContext:
|
||
def __init__(self):
|
||
self.session = None
|
||
self.exit_stack = AsyncExitStack()
|
||
|
||
async def __aenter__(self):
|
||
if config.transport_type is None:
|
||
if config.url:
|
||
config.transport_type = transportType.sse
|
||
elif config.command:
|
||
config.transport_type = transportType.stdio
|
||
else:
|
||
raise ValueError("Server config must have either url or command")
|
||
|
||
match config.transport_type:
|
||
case transportType.sse:
|
||
transport = await self.exit_stack.enter_async_context(
|
||
sse_client(url=config.url, headers=config.headers)
|
||
)
|
||
read, write = transport
|
||
case transportType.stdio:
|
||
transport = await self.exit_stack.enter_async_context(
|
||
stdio_client(StdioServerParameters(**config.model_dump()))
|
||
)
|
||
read, write = transport
|
||
case transportType.streamablehttp:
|
||
transport = await self.exit_stack.enter_async_context(
|
||
streamablehttp_client(url=config.url, headers=config.headers)
|
||
)
|
||
read, write, session_callback = transport
|
||
case _:
|
||
raise ValueError("Server config must have either url or command")
|
||
|
||
self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
|
||
await self.session.initialize()
|
||
return self.session
|
||
|
||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||
await self.exit_stack.aclose()
|
||
|
||
return SessionContext()
|
||
|
||
async def init_tools_cache(self):
|
||
"""初始化工具列表缓存"""
|
||
if not self._cache_initialized:
|
||
available_tools = []
|
||
logger.info(f"初始化工具列表缓存,需要连接{len(self.server_config)}个服务器")
|
||
for server_name in self.server_config.keys():
|
||
logger.debug(f"正在从服务器[{server_name}]获取工具列表")
|
||
async with self._create_session_context(server_name) as session:
|
||
response = await session.list_tools()
|
||
tools = response.tools
|
||
logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具")
|
||
|
||
available_tools.extend(
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": f"mcp__{server_name}__{tool.name}",
|
||
"description": tool.description,
|
||
"parameters": tool.inputSchema,
|
||
},
|
||
}
|
||
for tool in tools
|
||
)
|
||
|
||
# 缓存工具列表
|
||
self._tools_cache = available_tools
|
||
self._cache_initialized = True
|
||
|
||
logger.info(f"工具列表缓存完成,共缓存{len(available_tools)}个工具")
|
||
|
||
|
||
|
||
async def get_available_tools(self, is_group: bool, current_preset: PresetConfig | None = None):
|
||
"""获取可用工具列表,使用缓存机制
|
||
|
||
Args:
|
||
is_group: 是否群聊场景
|
||
current_preset: 当前使用的预设配置(用于获取子模型工具)
|
||
"""
|
||
await self.init_tools_cache()
|
||
available_tools = self._tools_cache.copy() if self._tools_cache else []
|
||
if is_group:
|
||
# 群聊场景,包含OneBot工具和MCP工具
|
||
available_tools.extend(self.onebot_tools.get_available_tools())
|
||
# 添加定时任务工具(群聊和私聊都可用)
|
||
available_tools.extend(self.scheduler_manager.get_available_tools())
|
||
# 添加子模型调用工具(根据当前预设的 call_model_list 动态生成)
|
||
if self.submodel_caller and current_preset:
|
||
submodel_tools = self.submodel_caller.get_available_tools(current_preset)
|
||
available_tools.extend(submodel_tools)
|
||
if submodel_tools:
|
||
logger.debug(f"添加了 {len(submodel_tools)} 个子模型调用工具")
|
||
logger.debug(f"获取可用工具列表,共{len(available_tools)}个工具")
|
||
return available_tools
|
||
|
||
async def call_tool(
|
||
self,
|
||
tool_name: str,
|
||
tool_args: dict,
|
||
group_id: int | None = None,
|
||
bot_id: str | None = None,
|
||
user_id: int | None = None,
|
||
is_group: bool = True,
|
||
current_preset: PresetConfig | None = None
|
||
):
|
||
"""按需连接调用工具,调用后立即断开
|
||
|
||
Args:
|
||
tool_name: 工具名称
|
||
tool_args: 工具参数
|
||
group_id: 群号(群聊时必需)
|
||
bot_id: 机器人ID
|
||
user_id: 用户ID
|
||
is_group: 是否群聊
|
||
current_preset: 当前使用的预设配置(子模型调用时必需)
|
||
"""
|
||
# 检查是否是OneBot内置工具
|
||
if tool_name.startswith("ob__"):
|
||
if group_id is None or bot_id is None:
|
||
return "QQ工具需要提供group_id和bot_id参数"
|
||
logger.info(f"调用OneBot工具[{tool_name}]")
|
||
return await self.onebot_tools.call_tool(tool_name, tool_args, group_id, bot_id)
|
||
|
||
# 检查是否是定时任务工具
|
||
if tool_name.startswith("scheduler__"):
|
||
context_id = group_id if is_group else user_id
|
||
if context_id is None or user_id is None:
|
||
return "定时任务工具需要提供context_id和user_id参数"
|
||
logger.info(f"调用定时任务工具[{tool_name}]")
|
||
return await self.scheduler_manager.call_tool(
|
||
tool_name, tool_args, context_id, is_group, user_id
|
||
)
|
||
|
||
# 检查是否是子模型调用工具
|
||
if tool_name.startswith("submodel__"):
|
||
if not self.submodel_caller:
|
||
return "子模型调用器未初始化"
|
||
if not current_preset:
|
||
return "子模型调用需要提供 current_preset 参数"
|
||
logger.info(f"调用子模型工具[{tool_name}]")
|
||
result = await self.submodel_caller.call_tool(tool_name, tool_args, current_preset)
|
||
# 返回结构化结果,让上层处理
|
||
return result
|
||
|
||
# 检查是否是MCP工具
|
||
if tool_name.startswith("mcp__"):
|
||
# MCP工具处理:mcp__server_name__tool_name
|
||
parts = tool_name.split("__")
|
||
if len(parts) != 3 or parts[0] != "mcp":
|
||
return f"MCP工具名称格式错误: {tool_name}"
|
||
|
||
server_name = parts[1]
|
||
real_tool_name = parts[2]
|
||
logger.info(f"按需连接到服务器[{server_name}]调用工具[{real_tool_name}]")
|
||
|
||
async with self._create_session_context(server_name) as session:
|
||
try:
|
||
response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30)
|
||
logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}")
|
||
return response.content
|
||
except asyncio.TimeoutError:
|
||
logger.error(f"调用工具[{real_tool_name}]超时")
|
||
return f"调用工具[{real_tool_name}]超时"
|
||
|
||
# 未知工具类型
|
||
return f"未知的工具类型: {tool_name}"
|
||
|
||
def get_friendly_name(self, tool_name: str):
|
||
logger.debug(tool_name)
|
||
# 检查是否是OneBot内置工具
|
||
if tool_name.startswith("ob__"):
|
||
return self.onebot_tools.get_friendly_name(tool_name)
|
||
|
||
# 检查是否是定时任务工具
|
||
if tool_name.startswith("scheduler__"):
|
||
return self.scheduler_manager.get_friendly_name(tool_name)
|
||
|
||
# 检查是否是子模型调用工具
|
||
if tool_name.startswith("submodel__"):
|
||
if self.submodel_caller:
|
||
return self.submodel_caller.get_friendly_name(tool_name)
|
||
return tool_name
|
||
|
||
# 检查是否是MCP工具
|
||
if tool_name.startswith("mcp__"):
|
||
# MCP工具处理:mcp__server_name__tool_name
|
||
parts = tool_name.split("__")
|
||
if len(parts) != 3 or parts[0] != "mcp":
|
||
return tool_name # 格式错误时返回原名称
|
||
|
||
server_name = parts[1]
|
||
real_tool_name = parts[2]
|
||
return (self.server_config[server_name].friendly_name or server_name) + " - " + real_tool_name
|
||
|
||
# 未知工具类型,返回原名称
|
||
return tool_name
|
||
|
||
def clear_tools_cache(self):
|
||
"""清除工具列表缓存"""
|
||
logger.info("清除工具列表缓存")
|
||
self._tools_cache = None
|
||
self._cache_initialized = False
|
||
|
||
async def cleanup(self):
|
||
"""清理资源(不销毁单例)"""
|
||
logger.debug("正在清理MCPClient资源")
|
||
# 只清除缓存,不销毁单例
|
||
# self.clear_tools_cache() # 保留缓存,避免重复获取工具列表
|
||
await self.exit_stack.aclose()
|
||
# 重新初始化exit_stack以便后续使用
|
||
self.exit_stack = AsyncExitStack()
|
||
logger.debug("MCPClient资源清理完成")
|
||
|
||
@classmethod
|
||
async def destroy_instance(cls):
|
||
"""完全销毁单例实例(仅在应用关闭时使用)"""
|
||
if cls._instance is not None:
|
||
logger.info("销毁MCPClient单例")
|
||
await cls._instance.cleanup()
|
||
cls._instance.clear_tools_cache()
|
||
cls._instance = None
|
||
cls._initialized = False
|
||
logger.debug("MCPClient单例已销毁")
|