Compare commits

..

No commits in common. "main" and "v0.5.2" have entirely different histories.
main ... v0.5.2

3 changed files with 69 additions and 148 deletions

View file

@ -395,14 +395,12 @@ async def process_messages(context_id: int, is_group: bool = True):
systemPrompt = "\n".join(system_lines)
if preset.support_mcp:
systemPrompt += "\n你也可以使用一些工具,下面是关于这些工具的额外说明:\n"
systemPrompt += "你也可以使用一些工具,下面是关于这些工具的额外说明:\n"
for mcp_name, mcp_config in plugin_config.mcp_servers.items():
if mcp_config.additional_prompt:
systemPrompt += f"{mcp_name}{mcp_config.additional_prompt}"
if mcp_config.addtional_prompt:
systemPrompt += f"{mcp_name}{mcp_config.addtional_prompt}"
systemPrompt += "\n"
logger.debug(f"构建系统提示词:\n{systemPrompt}")
messages: list[ChatCompletionMessageParam] = [
{"role": "system", "content": systemPrompt}
]
@ -480,23 +478,8 @@ async def process_messages(context_id: int, is_group: bool = True):
new_messages.append(llm_reply)
for tool_call in message.tool_calls:
logger.debug(f"处理工具调用:{tool_call.function.name} 参数:{tool_call.function.arguments}")
tool_name = tool_call.function.name
try:
tool_args = json.loads(tool_call.function.arguments)
except (json.JSONDecodeError, TypeError, ValueError) as e:
error_message = (
f"工具调用参数格式错误,无法解析 {tool_name} 的 arguments: {e!s}. "
f"原始参数: {tool_call.function.arguments}"
)
logger.warning(error_message)
new_messages.append({
"role": "tool",
"tool_call_id": tool_call.id,
"content": error_message,
})
continue
tool_args = json.loads(tool_call.function.arguments)
# 发送工具调用提示
await handler.send(Message(f"正在使用{mcp_client.get_friendly_name(tool_name)}"))

View file

@ -29,7 +29,7 @@ class MCPServerConfig(BaseModel):
# 额外字段
friendly_name: str | None = Field(None, description="MCP服务器友好名称")
additional_prompt: str | None = Field(None, description="额外提示词")
addtional_prompt: str | None = Field(None, description="额外提示词")
class ScopedConfig(BaseModel):
"""LLM Chat Plugin配置"""

View file

@ -1,7 +1,5 @@
import asyncio
from contextlib import AsyncExitStack
from time import monotonic
from typing import Any, cast
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
@ -15,8 +13,6 @@ from .onebottools import OneBotTools
class MCPClient:
_instance = None
_initialized = False
_SESSION_TTL_SECONDS = 600
_SESSION_CLEANUP_INTERVAL_SECONDS = 60
def __new__(cls, server_config: dict[str, MCPServerConfig] | None = None):
if cls._instance is None:
@ -34,10 +30,6 @@ class MCPClient:
self.server_config = server_config
self.sessions = {}
self.exit_stack = AsyncExitStack()
self._session_exit_stacks: dict[str, AsyncExitStack] = {}
self._session_last_used: dict[str, float] = {}
self._session_lock = asyncio.Lock()
self._session_cleanup_task: asyncio.Task | None = None
# 添加工具列表缓存
self._tools_cache: list | None = None
self._cache_initialized = False
@ -63,115 +55,80 @@ class MCPClient:
return cls._instance
async def connect_to_servers(self):
await self._ensure_cleanup_task()
logger.info(f"开始连接{len(self.server_config)}个MCP服务器")
for server_name in self.server_config:
for server_name, config in self.server_config.items():
logger.debug(f"正在连接服务器[{server_name}]")
await self._get_or_create_session(server_name)
if config.url:
sse_transport = await self.exit_stack.enter_async_context(sse_client(url=config.url, headers=config.headers))
read, write = sse_transport
self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
await self.sessions[server_name].initialize()
elif config.command:
stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(StdioServerParameters(**config.model_dump()))
)
read, write = stdio_transport
self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
await self.sessions[server_name].initialize()
else:
raise ValueError("Server config must have either url or command")
logger.info(f"已成功连接到MCP服务器[{server_name}]")
async def _create_server_session(self, server_name: str) -> tuple[ClientSession, AsyncExitStack]:
"""创建并初始化一个新的服务器会话。"""
def _create_session_context(self, server_name: str):
"""创建临时会话的异步上下文管理器"""
config = self.server_config[server_name]
session_stack = AsyncExitStack()
if config.url:
transport = await session_stack.enter_async_context(
sse_client(url=config.url, headers=config.headers)
)
elif config.command:
transport = await session_stack.enter_async_context(
cast(Any, stdio_client(StdioServerParameters(**config.model_dump())))
)
else:
raise ValueError("Server config must have either url or command")
read, write = transport
session = await session_stack.enter_async_context(ClientSession(read, write))
await session.initialize()
return session, session_stack
class SessionContext:
def __init__(self):
self.session = None
self.exit_stack = AsyncExitStack()
async def _close_server_session(self, server_name: str):
"""关闭指定服务器会话。"""
session_stack = self._session_exit_stacks.pop(server_name, None)
self.sessions.pop(server_name, None)
self._session_last_used.pop(server_name, None)
async def __aenter__(self):
if config.url:
transport = await self.exit_stack.enter_async_context(
sse_client(url=config.url, headers=config.headers)
)
elif config.command:
transport = await self.exit_stack.enter_async_context(
stdio_client(StdioServerParameters(**config.model_dump()))
)
else:
raise ValueError("Server config must have either url or command")
if session_stack is not None:
await session_stack.aclose()
read, write = transport
self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
await self.session.initialize()
return self.session
async def _get_or_create_session(self, server_name: str) -> ClientSession:
"""获取可复用会话;若不存在或已过期则新建。"""
now = monotonic()
async with self._session_lock:
last_used = self._session_last_used.get(server_name)
session = self.sessions.get(server_name)
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.exit_stack.aclose()
# 空闲超过阈值则销毁重建
if session is not None and last_used is not None:
if now - last_used > self._SESSION_TTL_SECONDS:
logger.info(f"服务器[{server_name}]会话空闲超过10分钟重新创建")
await self._close_server_session(server_name)
session = None
if session is None:
session, session_stack = await self._create_server_session(server_name)
self.sessions[server_name] = session
self._session_exit_stacks[server_name] = session_stack
self._session_last_used[server_name] = now
return self.sessions[server_name]
async def _cleanup_expired_sessions(self):
"""回收空闲过期会话。"""
now = monotonic()
async with self._session_lock:
expired_servers = [
server_name
for server_name, last_used in self._session_last_used.items()
if now - last_used > self._SESSION_TTL_SECONDS
]
for server_name in expired_servers:
logger.info(f"回收空闲MCP会话[{server_name}]")
await self._close_server_session(server_name)
async def _session_cleanup_loop(self):
try:
while True:
await asyncio.sleep(self._SESSION_CLEANUP_INTERVAL_SECONDS)
await self._cleanup_expired_sessions()
except asyncio.CancelledError:
logger.debug("MCP会话清理任务已取消")
raise
async def _ensure_cleanup_task(self):
if self._session_cleanup_task is None or self._session_cleanup_task.done():
self._session_cleanup_task = asyncio.create_task(self._session_cleanup_loop())
return SessionContext()
async def init_tools_cache(self):
"""初始化工具列表缓存"""
if not self._cache_initialized:
await self._ensure_cleanup_task()
available_tools = []
logger.info(f"初始化工具列表缓存,需要连接{len(self.server_config)}个服务器")
for server_name in self.server_config.keys():
logger.debug(f"正在从服务器[{server_name}]获取工具列表")
session = await self._get_or_create_session(server_name)
response = await session.list_tools()
tools = response.tools
logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具")
async with self._create_session_context(server_name) as session:
response = await session.list_tools()
tools = response.tools
logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具")
available_tools.extend(
{
"type": "function",
"function": {
"name": f"mcp__{server_name}__{tool.name}",
"description": tool.description,
"parameters": tool.inputSchema,
},
}
for tool in tools
)
available_tools.extend(
{
"type": "function",
"function": {
"name": f"mcp__{server_name}__{tool.name}",
"description": tool.description,
"parameters": tool.inputSchema,
},
}
for tool in tools
)
# 缓存工具列表
self._tools_cache = available_tools
@ -192,7 +149,7 @@ class MCPClient:
return available_tools
async def call_tool(self, tool_name: str, tool_args: dict, group_id: int | None = None, bot_id: str | None = None):
"""按需调用工具MCP会话会在10分钟空闲后自动回收。"""
"""按需连接调用工具,调用后立即断开"""
# 检查是否是OneBot内置工具
if tool_name.startswith("ob__"):
if group_id is None or bot_id is None:
@ -211,20 +168,14 @@ class MCPClient:
real_tool_name = parts[2]
logger.info(f"按需连接到服务器[{server_name}]调用工具[{real_tool_name}]")
try:
await self._ensure_cleanup_task()
session = await self._get_or_create_session(server_name)
response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30)
logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}")
return response.content
except asyncio.TimeoutError:
logger.error(f"调用工具[{real_tool_name}]超时")
return f"调用工具[{real_tool_name}]超时"
except (RuntimeError, ValueError, TypeError, OSError, ConnectionError) as e:
logger.opt(exception=e).error(f"调用工具[{real_tool_name}]失败,准备重置会话")
async with self._session_lock:
await self._close_server_session(server_name)
return f"调用工具[{real_tool_name}]失败: {e!s}"
async with self._create_session_context(server_name) as session:
try:
response = await asyncio.wait_for(session.call_tool(real_tool_name, tool_args), timeout=30)
logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}")
return response.content
except asyncio.TimeoutError:
logger.error(f"调用工具[{real_tool_name}]超时")
return f"调用工具[{real_tool_name}]超时"
# 未知工具类型
return f"未知的工具类型: {tool_name}"
@ -260,19 +211,6 @@ class MCPClient:
logger.debug("正在清理MCPClient资源")
# 只清除缓存,不销毁单例
# self.clear_tools_cache() # 保留缓存,避免重复获取工具列表
if self._session_cleanup_task is not None:
self._session_cleanup_task.cancel()
try:
await self._session_cleanup_task
except asyncio.CancelledError:
pass
self._session_cleanup_task = None
async with self._session_lock:
for server_name in list(self.sessions.keys()):
await self._close_server_session(server_name)
await self.exit_stack.aclose()
# 重新初始化exit_stack以便后续使用
self.exit_stack = AsyncExitStack()