From eb1038e09edc260ec142e6b90ee89d1a8a3038c7 Mon Sep 17 00:00:00 2001 From: FuQuan233 Date: Fri, 25 Apr 2025 23:52:18 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20support=20Model=20Context=20Protoco?= =?UTF-8?q?l=20(MCP)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot_plugin_llmchat/__init__.py | 80 ++++++++++++++++++++++++----- nonebot_plugin_llmchat/config.py | 14 +++++ nonebot_plugin_llmchat/mcpclient.py | 78 ++++++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 12 deletions(-) create mode 100644 nonebot_plugin_llmchat/mcpclient.py diff --git a/nonebot_plugin_llmchat/__init__.py b/nonebot_plugin_llmchat/__init__.py index 6eb3df2..8e5ef06 100644 --- a/nonebot_plugin_llmchat/__init__.py +++ b/nonebot_plugin_llmchat/__init__.py @@ -28,6 +28,7 @@ from nonebot.rule import Rule from openai import AsyncOpenAI from .config import Config, PresetConfig +from .mcpclient import MCPClient require("nonebot_plugin_localstore") import nonebot_plugin_localstore as store @@ -241,6 +242,11 @@ async def process_messages(group_id: int): - 如果你需要思考的话,你应该思考尽量少,以节省时间。 下面是关于你性格的设定,如果设定中提到让你扮演某个人,或者设定中有提到名字,则优先使用设定中的名字。 {state.group_prompt or plugin_config.default_prompt} +""" + if preset.support_mcp: + systemPrompt += f""" +你也可以使用一些工具,下面是关于这些工具的额外说明: +{"\n".join([mcp_config.addtional_prompt for mcp_config in plugin_config.mcp_servers.values()])} """ messages: Iterable[ChatCompletionMessageParam] = [ @@ -256,20 +262,74 @@ async def process_messages(group_id: int): # 将机器人错过的消息推送给LLM content = ",".join([format_message(ev) for ev in state.past_events]) + new_messages = [{"role": "user", "content": content}] + logger.debug( f"发送API请求 模型:{preset.model_name} 历史消息数:{len(messages)}" ) + mcp_client = MCPClient(plugin_config.mcp_servers) + await mcp_client.connect_to_servers() + + available_tools = await mcp_client.get_available_tools() + + client_config = { + "model": preset.model_name, + "max_tokens": preset.max_tokens, + "temperature": preset.temperature, + "timeout": 60, + } + + if preset.support_mcp: + client_config["tools"] = available_tools + response = await client.chat.completions.create( - model=preset.model_name, - messages=[*messages, {"role": "user", "content": content}], - max_tokens=preset.max_tokens, - temperature=preset.temperature, - timeout=60, + **client_config, + messages=messages + new_messages, ) if response.usage is not None: logger.debug(f"收到API响应 使用token数:{response.usage.total_tokens}") + final_message = [] + message = response.choices[0].message + if message.content: + final_message.append(message.content) + + # 处理响应并处理工具调用 + while preset.support_mcp and message.tool_calls: + new_messages.append({ + "role": "assistant", + "tool_calls": [tool_call.dict() for tool_call in message.tool_calls] + }) + # 处理每个工具调用 + for tool_call in message.tool_calls: + tool_name = tool_call.function.name + tool_args = json.loads(tool_call.function.arguments) + + # 发送工具调用提示 + await handler.send(Message(f"正在使用{mcp_client.get_friendly_name(tool_name)}")) + + # 执行工具调用 + result = await mcp_client.call_tool(tool_name, tool_args) + + new_messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": str(result.content) + }) + + # 将工具调用的结果交给 LLM + response = await client.chat.completions.create( + **client_config, + messages=messages + new_messages, + ) + + message = response.choices[0].message + if message.content: + final_message.append(message.content) + + await mcp_client.cleanup() + reply, matched_reasoning_content = pop_reasoning_content( response.choices[0].message.content ) @@ -279,9 +339,8 @@ async def process_messages(group_id: int): ) # 请求成功后再保存历史记录,保证user和assistant穿插,防止R1模型报错 - state.history.append({"role": "user", "content": content}) - # 添加助手回复到历史 - state.history.append({"role": "assistant", "content": reply}) + for message in new_messages: + state.history.append(message) state.past_events.clear() if state.output_reasoning_content and reasoning_content: @@ -467,10 +526,7 @@ async def load_state(): state.last_active = state_data["last_active"] state.group_prompt = state_data["group_prompt"] state.output_reasoning_content = state_data["output_reasoning_content"] - state.random_trigger_prob = ( - state_data.get("random_trigger_prob") - or plugin_config.random_trigger_prob - ) + state.random_trigger_prob = state_data.get("random_trigger_prob", plugin_config.random_trigger_prob) group_states[int(gid)] = state diff --git a/nonebot_plugin_llmchat/config.py b/nonebot_plugin_llmchat/config.py index 3de419e..57737fa 100644 --- a/nonebot_plugin_llmchat/config.py +++ b/nonebot_plugin_llmchat/config.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import BaseModel, Field @@ -11,7 +13,18 @@ class PresetConfig(BaseModel): max_tokens: int = Field(2048, description="最大响应token数") temperature: float = Field(0.7, description="生成温度(0-2]") proxy: str = Field("", description="HTTP代理服务器") + support_mcp: bool = Field(False, description="是否支持MCP") +class MCPServerConfig(BaseModel): + """MCP服务器配置""" + command: Optional[str] = Field(None, description="stdio模式下MCP命令") + args: Optional[list[str]] = Field([], description="stdio模式下MCP命令参数") + env: Optional[dict[str, str]] = Field({}, description="stdio模式下MCP命令环境变量") + url: Optional[str] = Field(None, description="sse模式下MCP服务器地址") + + # 额外字段 + friendly_name: str = Field("", description="MCP服务器友好名称") + addtional_prompt: str = Field("", description="额外提示词") class ScopedConfig(BaseModel): """LLM Chat Plugin配置""" @@ -30,6 +43,7 @@ class ScopedConfig(BaseModel): "你的回答应该尽量简洁、幽默、可以使用一些语气词、颜文字。你应该拒绝回答任何政治相关的问题。", description="默认提示词", ) + mcp_servers: dict[str, MCPServerConfig] = Field({}, description="MCP服务器配置") class Config(BaseModel): diff --git a/nonebot_plugin_llmchat/mcpclient.py b/nonebot_plugin_llmchat/mcpclient.py new file mode 100644 index 0000000..3ea7d3d --- /dev/null +++ b/nonebot_plugin_llmchat/mcpclient.py @@ -0,0 +1,78 @@ +from contextlib import AsyncExitStack + +from mcp import ClientSession, StdioServerParameters +from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client +from nonebot import logger + +from .config import MCPServerConfig + + +class MCPClient: + def __init__(self, server_config: dict[str, MCPServerConfig]): + logger.info(f"正在初始化MCPClient,共有{len(server_config)}个服务器配置") + self.server_config = server_config + self.sessions = {} + self.exit_stack = AsyncExitStack() + logger.debug("MCPClient初始化成功") + + async def connect_to_servers(self): + logger.info(f"开始连接{len(self.server_config)}个MCP服务器") + for server_name, config in self.server_config.items(): + logger.debug(f"正在连接服务器[{server_name}]") + if config.url: + sse_transport = await self.exit_stack.enter_async_context(sse_client(url=config.url)) + read, write = sse_transport + self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write)) + await self.sessions[server_name].initialize() + elif config.command: + stdio_transport = await self.exit_stack.enter_async_context( + stdio_client(StdioServerParameters(**config.model_dump())) + ) + read, write = stdio_transport + self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write)) + await self.sessions[server_name].initialize() + else: + raise ValueError("Server config must have either url or command") + + logger.info(f"已成功连接到MCP服务器[{server_name}]") + + async def get_available_tools(self): + logger.info(f"正在从{len(self.sessions)}个已连接的服务器获取可用工具") + available_tools = [] + + for server_name, session in self.sessions.items(): + logger.debug(f"正在列出服务器[{server_name}]中的工具") + response = await session.list_tools() + tools = response.tools + logger.debug(f"在服务器[{server_name}]中找到{len(tools)}个工具") + + available_tools.extend( + { + "type": "function", + "function": { + "name": f"{server_name}___{tool.name}", + "description": tool.description, + "parameters": tool.inputSchema, + }, + } + for tool in tools + ) + return available_tools + + async def call_tool(self, tool_name: str, tool_args: dict): + server_name, real_tool_name = tool_name.split("___") + logger.info(f"正在服务器[{server_name}]上调用工具[{real_tool_name}]") + session = self.sessions[server_name] + response = await session.call_tool(real_tool_name, tool_args) + logger.debug(f"工具[{real_tool_name}]调用完成,响应: {response}") + return response + + def get_friendly_name(self, tool_name: str): + server_name, real_tool_name = tool_name.split("___") + return self.server_config[server_name].friendly_name + " - " + real_tool_name + + async def cleanup(self): + logger.debug("正在清理MCPClient资源") + await self.exit_stack.aclose() + logger.debug("MCPClient资源清理完成")