Compare commits

..

No commits in common. "main" and "v0.5.2" have entirely different histories.
main ... v0.5.2

3 changed files with 69 additions and 148 deletions

View file

@ -395,14 +395,12 @@ async def process_messages(context_id: int, is_group: bool = True):
systemPrompt = "\n".join(system_lines) systemPrompt = "\n".join(system_lines)
if preset.support_mcp: if preset.support_mcp:
systemPrompt += "\n你也可以使用一些工具,下面是关于这些工具的额外说明:\n" systemPrompt += "你也可以使用一些工具,下面是关于这些工具的额外说明:\n"
for mcp_name, mcp_config in plugin_config.mcp_servers.items(): for mcp_name, mcp_config in plugin_config.mcp_servers.items():
if mcp_config.additional_prompt: if mcp_config.addtional_prompt:
systemPrompt += f"{mcp_name}{mcp_config.additional_prompt}" systemPrompt += f"{mcp_name}{mcp_config.addtional_prompt}"
systemPrompt += "\n" systemPrompt += "\n"
logger.debug(f"构建系统提示词:\n{systemPrompt}")
messages: list[ChatCompletionMessageParam] = [ messages: list[ChatCompletionMessageParam] = [
{"role": "system", "content": systemPrompt} {"role": "system", "content": systemPrompt}
] ]
@ -480,23 +478,8 @@ async def process_messages(context_id: int, is_group: bool = True):
new_messages.append(llm_reply) new_messages.append(llm_reply)
for tool_call in message.tool_calls: for tool_call in message.tool_calls:
logger.debug(f"处理工具调用:{tool_call.function.name} 参数:{tool_call.function.arguments}")
tool_name = tool_call.function.name tool_name = tool_call.function.name
try: tool_args = json.loads(tool_call.function.arguments)
tool_args = json.loads(tool_call.function.arguments)
except (json.JSONDecodeError, TypeError, ValueError) as e:
error_message = (
f"工具调用参数格式错误,无法解析 {tool_name} 的 arguments: {e!s}. "
f"原始参数: {tool_call.function.arguments}"
)
logger.warning(error_message)
new_messages.append({
"role": "tool",
"tool_call_id": tool_call.id,
"content": error_message,
})
continue
# 发送工具调用提示 # 发送工具调用提示
await handler.send(Message(f"正在使用{mcp_client.get_friendly_name(tool_name)}")) await handler.send(Message(f"正在使用{mcp_client.get_friendly_name(tool_name)}"))

View file

@ -29,7 +29,7 @@ class MCPServerConfig(BaseModel):
# 额外字段 # 额外字段
friendly_name: str | None = Field(None, description="MCP服务器友好名称") friendly_name: str | None = Field(None, description="MCP服务器友好名称")
additional_prompt: str | None = Field(None, description="额外提示词") addtional_prompt: str | None = Field(None, description="额外提示词")
class ScopedConfig(BaseModel): class ScopedConfig(BaseModel):
"""LLM Chat Plugin配置""" """LLM Chat Plugin配置"""

View file

@ -1,7 +1,5 @@
import asyncio import asyncio
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from time import monotonic
from typing import Any, cast
from mcp import ClientSession, StdioServerParameters from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
@ -15,8 +13,6 @@ from .onebottools import OneBotTools
class MCPClient: class MCPClient:
_instance = None _instance = None
_initialized = False _initialized = False
_SESSION_TTL_SECONDS = 600
_SESSION_CLEANUP_INTERVAL_SECONDS = 60
def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None): def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None):
if cls._instance is None: if cls._instance is None:
@ -34,10 +30,6 @@ class MCPClient:
self.server_config = server_config self.server_config = server_config
self.sessions = {} self.sessions = {}
self.exit_stack = AsyncExitStack() self.exit_stack = AsyncExitStack()
self._session_exit_stacks: dict[str, AsyncExitStack] = {}
self._session_last_used: dict[str, float] = {}
self._session_lock = asyncio.Lock()
self._session_cleanup_task: asyncio.Task | None = None
# 添加工具列表缓存 # 添加工具列表缓存
self._tools_cache: list | None = None self._tools_cache: list | None = None
self._cache_initialized = False self._cache_initialized = False
@ -63,115 +55,80 @@ class MCPClient:
return cls._instance return cls._instance
async def connect_to_servers(self): async def connect_to_servers(self):
await self._ensure_cleanup_task()
logger.info(f"开始连接{len(self.server_config)}个MCP服务器") logger.info(f"开始连接{len(self.server_config)}个MCP服务器")
for server_name in self.server_config: for server_name, config in self.server_config.items():
logger.debug(f"正在连接服务器[{server_name}]") logger.debug(f"正在连接服务器[{server_name}]")
await self._get_or_create_session(server_name) if config.url:
sse_transport = await self.exit_stack.enter_async_context(sse_client(url=config.url, headers=config.headers))
read, write = sse_transport
self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
await self.sessions[server_name].initialize()
elif config.command:
stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(StdioServerParameters(**config.model_dump()))
)
read, write = stdio_transport
self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
await self.sessions[server_name].initialize()
else:
raise ValueError("Server config must have either url or command")
logger.info(f"已成功连接到MCP服务器[{server_name}]") logger.info(f"已成功连接到MCP服务器[{server_name}]")
async def _create_server_session(self, server_name: str) -> tuple[ClientSession, AsyncExitStack]: def _create_session_context(self, server_name: str):
"""创建并初始化一个新的服务器会话。""" """创建临时会话的异步上下文管理器"""
config = self.server_config[server_name] config = self.server_config[server_name]
session_stack = AsyncExitStack()
if config.url:
transport = await session_stack.enter_async_context(
sse_client(url=config.url, headers=config.headers)
)
elif config.command:
transport = await session_stack.enter_async_context(
cast(Any, stdio_client(StdioServerParameters(**config.model_dump())))
)
else:
raise ValueError("Server config must have either url or command")
read, write = transport class SessionContext:
session = await session_stack.enter_async_context(ClientSession(read, write)) def __init__(self):
await session.initialize() self.session = None
return session, session_stack self.exit_stack = AsyncExitStack()
async def _close_server_session(self, server_name: str): async def __aenter__(self):
"""关闭指定服务器会话。""" if config.url:
session_stack = self._session_exit_stacks.pop(server_name, None) transport = await self.exit_stack.enter_async_context(
self.sessions.pop(server_name, None) sse_client(url=config.url, headers=config.headers)
self._session_last_used.pop(server_name, None) )
elif config.command:
transport = await self.exit_stack.enter_async_context(
stdio_client(StdioServerParameters(**config.model_dump()))
)
else:
raise ValueError("Server config must have either url or command")
if session_stack is not None: read, write = transport
await session_stack.aclose() self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
await self.session.initialize()
return self.session
async def _get_or_create_session(self, server_name: str) -> ClientSession: async def __aexit__(self, exc_type, exc_val, exc_tb):
"""获取可复用会话;若不存在或已过期则新建。""" await self.exit_stack.aclose()
now = monotonic()
async with self._session_lock:
last_used = self._session_last_used.get(server_name)
session = self.sessions.get(server_name)
# 空闲超过阈值则销毁重建 return SessionContext()
if session is not None and last_used is not None:
if now - last_used > self._SESSION_TTL_SECONDS:
logger.info(f"服务器[{server_name}]会话空闲超过10分钟重新创建")
await self._close_server_session(server_name)
session = None
if session is None:
session, session_stack = await self._create_server_session(server_name)
self.sessions[server_name] = session
self._session_exit_stacks[server_name] = session_stack
self._session_last_used[server_name] = now
return self.sessions[server_name]
async def _cleanup_expired_sessions(self):
"""回收空闲过期会话。"""
now = monotonic()
async with self._session_lock:
expired_servers = [
server_name
for server_name, last_used in self._session_last_used.items()
if now - last_used > self._SESSION_TTL_SECONDS
]
for server_name in expired_servers:
logger.info(f"回收空闲MCP会话[{server_name}]")
await self._close_server_session(server_name)
async def _session_cleanup_loop(self):
try:
while True:
await asyncio.sleep(self._SESSION_CLEANUP_INTERVAL_SECONDS)
await self._cleanup_expired_sessions()
except asyncio.CancelledError:
logger.debug("MCP会话清理任务已取消")
raise
async def _ensure_cleanup_task(self):
if self._session_cleanup_task is None or self._session_cleanup_task.done():
self._session_cleanup_task = asyncio.create_task(self._session_cleanup_loop())
async def init_tools_cache(self): async def init_tools_cache(self):
"""初始化工具列表缓存""" """初始化工具列表缓存"""
if not self._cache_initialized: if not self._cache_initialized:
await self._ensure_cleanup_task()
available_tools = [] available_tools = []
logger.info(f"初始化工具列表缓存,需要连接{len(self.server_config)}个服务器") logger.info(f"初始化工具列表缓存,需要连接{len(self.server_config)}个服务器")
for server_name in self.server_config.keys(): for server_name in self.server_config.keys():
logger.debug(f"正在从服务器[{server_name}]获取工具列表") logger.debug(f"正在从服务器[{server_name}]获取工具列表")
session = await self._get_or_create_session(server_name) async with self._create_session_context(server_name) as session:
response = await session.list_tools() response = await session.list_tools()
tools = response.tools tools = response.tools
logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具") logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具")
available_tools.extend( available_tools.extend(
{ {
"type": "function", "type": "function",
"function": { "function": {
"name": f"mcp__{server_name}__{tool.name}", "name": f"mcp__{server_name}__{tool.name}",
"description": tool.description, "description": tool.description,
"parameters": tool.inputSchema, "parameters": tool.inputSchema,
}, },
} }
for tool in tools for tool in tools
) )
# 缓存工具列表 # 缓存工具列表
self._tools_cache = available_tools self._tools_cache = available_tools
@ -192,7 +149,7 @@ class MCPClient:
return available_tools return available_tools
async def call_tool(self, tool_name: str, tool_args: dict, group_id: int | None = None, bot_id: str | None = None): async def call_tool(self, tool_name: str, tool_args: dict, group_id: int | None = None, bot_id: str | None = None):
"""按需调用工具MCP会话会在10分钟空闲后自动回收。""" """按需连接调用工具,调用后立即断开"""
# 检查是否是OneBot内置工具 # 检查是否是OneBot内置工具
if tool_name.startswith("ob__"): if tool_name.startswith("ob__"):
if group_id is None or bot_id is None: if group_id is None or bot_id is None:
@ -211,20 +168,14 @@ class MCPClient:
real_tool_name = parts[2] real_tool_name = parts[2]
logger.info(f"按需连接到服务器[{server_name}]调用工具[{real_tool_name}]") logger.info(f"按需连接到服务器[{server_name}]调用工具[{real_tool_name}]")
try: async with self._create_session_context(server_name) as session:
await self._ensure_cleanup_task() try:
session = await self._get_or_create_session(server_name) response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30)
response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30) logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}")
logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}") return response.content
return response.content except asyncio.TimeoutError:
except asyncio.TimeoutError: logger.error(f"调用工具[{real_tool_name}]超时")
logger.error(f"调用工具[{real_tool_name}]超时") return f"调用工具[{real_tool_name}]超时"
return f"调用工具[{real_tool_name}]超时"
except (RuntimeError, ValueError, TypeError, OSError, ConnectionError) as e:
logger.opt(exception=e).error(f"调用工具[{real_tool_name}]失败,准备重置会话")
async with self._session_lock:
await self._close_server_session(server_name)
return f"调用工具[{real_tool_name}]失败: {e!s}"
# 未知工具类型 # 未知工具类型
return f"未知的工具类型: {tool_name}" return f"未知的工具类型: {tool_name}"
@ -260,19 +211,6 @@ class MCPClient:
logger.debug("正在清理MCPClient资源") logger.debug("正在清理MCPClient资源")
# 只清除缓存,不销毁单例 # 只清除缓存,不销毁单例
# self.clear_tools_cache() # 保留缓存,避免重复获取工具列表 # self.clear_tools_cache() # 保留缓存,避免重复获取工具列表
if self._session_cleanup_task is not None:
self._session_cleanup_task.cancel()
try:
await self._session_cleanup_task
except asyncio.CancelledError:
pass
self._session_cleanup_task = None
async with self._session_lock:
for server_name in list(self.sessions.keys()):
await self._close_server_session(server_name)
await self.exit_stack.aclose() await self.exit_stack.aclose()
# 重新初始化exit_stack以便后续使用 # 重新初始化exit_stack以便后续使用
self.exit_stack = AsyncExitStack() self.exit_stack = AsyncExitStack()