support Model Context Protocol (MCP)

FuQuan233 2025-04-25 23:52:18 +08:00
parent dfe3b5308c
commit eb1038e09e
3 changed files with 160 additions and 12 deletions

nonebot_plugin_llmchat/__init__.py

@@ -28,6 +28,7 @@ from nonebot.rule import Rule
from openai import AsyncOpenAI
from .config import Config, PresetConfig
from .mcpclient import MCPClient
require("nonebot_plugin_localstore")
import nonebot_plugin_localstore as store
@@ -241,6 +242,11 @@ async def process_messages(group_id: int):
- If you need to think, think as little as possible to save time.
Below is your persona setting. If the setting asks you to role-play as someone, or mentions a name, prefer the name given in the setting.
{state.group_prompt or plugin_config.default_prompt}
"""
if preset.support_mcp:
    mcp_prompts = "\n".join(
        mcp_config.addtional_prompt
        for mcp_config in plugin_config.mcp_servers.values()
    )
    systemPrompt += f"""
You can also use some tools. Here are additional notes about these tools:
{mcp_prompts}
"""
messages: Iterable[ChatCompletionMessageParam] = [
@@ -256,20 +262,74 @@ async def process_messages(group_id: int):
# Push the messages the bot missed to the LLM
content = ",".join([format_message(ev) for ev in state.past_events])
new_messages = [{"role": "user", "content": content}]
logger.debug(
    f"Sending API request, model: {preset.model_name}, history length: {len(messages)}"
)
mcp_client = MCPClient(plugin_config.mcp_servers)
await mcp_client.connect_to_servers()
available_tools = await mcp_client.get_available_tools()
client_config = {
"model": preset.model_name,
"max_tokens": preset.max_tokens,
"temperature": preset.temperature,
"timeout": 60,
}
if preset.support_mcp:
client_config["tools"] = available_tools
- response = await client.chat.completions.create(
-     model=preset.model_name,
-     messages=[*messages, {"role": "user", "content": content}],
-     max_tokens=preset.max_tokens,
-     temperature=preset.temperature,
-     timeout=60,
- )
+ response = await client.chat.completions.create(
+     **client_config,
+     messages=messages + new_messages,
+ )
if response.usage is not None:
logger.debug(f"收到API响应 使用token数{response.usage.total_tokens}")
final_message = []
message = response.choices[0].message
if message.content:
final_message.append(message.content)
# Process the response, handling tool calls as needed
while preset.support_mcp and message.tool_calls:
new_messages.append({
"role": "assistant",
"tool_calls": [tool_call.dict() for tool_call in message.tool_calls]
})
# Handle each tool call
for tool_call in message.tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
# Announce which tool is about to be used
await handler.send(Message(f"Using {mcp_client.get_friendly_name(tool_name)}"))
# Execute the tool call
result = await mcp_client.call_tool(tool_name, tool_args)
new_messages.append({
"role": "tool",
"tool_call_id": tool_call.id,
"content": str(result.content)
})
# Hand the tool results back to the LLM
response = await client.chat.completions.create(
**client_config,
messages=messages + new_messages,
)
message = response.choices[0].message
if message.content:
final_message.append(message.content)
await mcp_client.cleanup()
reply, matched_reasoning_content = pop_reasoning_content(
response.choices[0].message.content
)
@@ -279,9 +339,8 @@ async def process_messages(group_id: int):
)
# Save history only after the request succeeds, keeping user and assistant messages interleaved to avoid errors from R1 models
- state.history.append({"role": "user", "content": content})
- # Add the assistant reply to history
- state.history.append({"role": "assistant", "content": reply})
+ for message in new_messages:
+     state.history.append(message)
state.past_events.clear()
if state.output_reasoning_content and reasoning_content:
@@ -467,10 +526,7 @@ async def load_state():
state.last_active = state_data["last_active"]
state.group_prompt = state_data["group_prompt"]
state.output_reasoning_content = state_data["output_reasoning_content"]
- state.random_trigger_prob = (
-     state_data.get("random_trigger_prob")
-     or plugin_config.random_trigger_prob
- )
+ state.random_trigger_prob = state_data.get("random_trigger_prob", plugin_config.random_trigger_prob)
group_states[int(gid)] = state

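The tool-call loop in process_messages above is the standard OpenAI function-calling round trip: offer the tools with the request, execute whatever tool_calls the model returns, append each result as a role "tool" message tied to its call id, and re-query until the model answers in plain text. Below is a condensed, self-contained sketch of that pattern; the model name and the execute_tool callback are placeholder assumptions, not part of this commit.

import json

from openai import AsyncOpenAI


async def tool_call_round_trip(client: AsyncOpenAI, messages: list, tools: list, execute_tool):
    # Ask the model, offering it the available tools
    response = await client.chat.completions.create(
        model="some-model",  # placeholder model name
        messages=messages,
        tools=tools,
    )
    message = response.choices[0].message
    while message.tool_calls:
        # Record the assistant's tool request in the conversation
        messages.append(
            {"role": "assistant", "tool_calls": [tc.model_dump() for tc in message.tool_calls]}
        )
        for tc in message.tool_calls:
            # Run the tool and feed its result back under the matching call id
            result = await execute_tool(tc.function.name, json.loads(tc.function.arguments))
            messages.append({"role": "tool", "tool_call_id": tc.id, "content": str(result)})
        response = await client.chat.completions.create(
            model="some-model", messages=messages, tools=tools
        )
        message = response.choices[0].message
    return message.content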
nonebot_plugin_llmchat/config.py

@@ -1,3 +1,5 @@
from typing import Optional
from pydantic import BaseModel, Field
@@ -11,7 +13,18 @@ class PresetConfig(BaseModel):
    max_tokens: int = Field(2048, description="Maximum number of response tokens")
    temperature: float = Field(0.7, description="Sampling temperature [0-2]")
    proxy: str = Field("", description="HTTP proxy server")
    support_mcp: bool = Field(False, description="Whether this preset supports MCP")
class MCPServerConfig(BaseModel):
    """MCP server configuration"""
    command: Optional[str] = Field(None, description="MCP command in stdio mode")
    args: Optional[list[str]] = Field([], description="Arguments for the MCP command in stdio mode")
    env: Optional[dict[str, str]] = Field({}, description="Environment variables for the MCP command in stdio mode")
    url: Optional[str] = Field(None, description="MCP server URL in SSE mode")
    # Extra fields
    friendly_name: str = Field("", description="Friendly display name of the MCP server")
    addtional_prompt: str = Field("", description="Additional prompt for this server's tools")
class ScopedConfig(BaseModel):
    """LLM Chat plugin configuration"""
@@ -30,6 +43,7 @@ class ScopedConfig(BaseModel):
"你的回答应该尽量简洁、幽默、可以使用一些语气词、颜文字。你应该拒绝回答任何政治相关的问题。",
description="默认提示词",
)
mcp_servers: dict[str, MCPServerConfig] = Field({}, description="MCP服务器配置")
class Config(BaseModel):

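For reference, a minimal sketch of what an mcp_servers configuration could look like with the MCPServerConfig schema above; the server names, command, and URL are hypothetical examples, not part of this commit.

mcp_servers = {
    "weather": {
        # stdio mode: spawn a local MCP server process
        "command": "uvx",
        "args": ["mcp-server-weather"],
        "env": {"WEATHER_API_KEY": "..."},
        "friendly_name": "Weather",
        "addtional_prompt": "Use this tool when the user asks about the weather.",
    },
    "search": {
        # SSE mode: connect to an already running MCP server over HTTP
        "url": "http://localhost:8000/sse",
        "friendly_name": "Web search",
    },
}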
nonebot_plugin_llmchat/mcpclient.py

@@ -0,0 +1,78 @@
from contextlib import AsyncExitStack
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from nonebot import logger
from .config import MCPServerConfig

class MCPClient:
    def __init__(self, server_config: dict[str, MCPServerConfig]):
        logger.info(f"Initializing MCPClient with {len(server_config)} server configuration(s)")
        self.server_config = server_config
        self.sessions = {}
        self.exit_stack = AsyncExitStack()
        logger.debug("MCPClient initialized successfully")

    async def connect_to_servers(self):
        logger.info(f"Connecting to {len(self.server_config)} MCP server(s)")
        for server_name, config in self.server_config.items():
            logger.debug(f"Connecting to server [{server_name}]")
            if config.url:
                # SSE mode: connect to a running server over HTTP
                sse_transport = await self.exit_stack.enter_async_context(sse_client(url=config.url))
                read, write = sse_transport
                self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
                await self.sessions[server_name].initialize()
            elif config.command:
                # stdio mode: spawn the server as a subprocess
                stdio_transport = await self.exit_stack.enter_async_context(
                    stdio_client(StdioServerParameters(**config.model_dump()))
                )
                read, write = stdio_transport
                self.sessions[server_name] = await self.exit_stack.enter_async_context(ClientSession(read, write))
                await self.sessions[server_name].initialize()
            else:
                raise ValueError("Server config must have either url or command")
            logger.info(f"Successfully connected to MCP server [{server_name}]")

    async def get_available_tools(self):
        logger.info(f"Fetching available tools from {len(self.sessions)} connected server(s)")
        available_tools = []
        for server_name, session in self.sessions.items():
            logger.debug(f"Listing tools on server [{server_name}]")
            response = await session.list_tools()
            tools = response.tools
            logger.debug(f"Found {len(tools)} tool(s) on server [{server_name}]")
            # Prefix each tool name with its server name so calls can be routed back later
            available_tools.extend(
                {
                    "type": "function",
                    "function": {
                        "name": f"{server_name}___{tool.name}",
                        "description": tool.description,
                        "parameters": tool.inputSchema,
                    },
                }
                for tool in tools
            )
        return available_tools

    async def call_tool(self, tool_name: str, tool_args: dict):
        # Split the prefixed name back into the server name and the server-local tool name
        server_name, real_tool_name = tool_name.split("___")
        logger.info(f"Calling tool [{real_tool_name}] on server [{server_name}]")
        session = self.sessions[server_name]
        response = await session.call_tool(real_tool_name, tool_args)
        logger.debug(f"Tool [{real_tool_name}] call finished, response: {response}")
        return response

    def get_friendly_name(self, tool_name: str):
        server_name, real_tool_name = tool_name.split("___")
        return self.server_config[server_name].friendly_name + " - " + real_tool_name

    async def cleanup(self):
        logger.debug("Cleaning up MCPClient resources")
        await self.exit_stack.aclose()
        logger.debug("MCPClient resources cleaned up")